/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

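// An activation can safely be computed in place (Out reusing X's buffer)
// only when its grad functor does not need the forward input X, i.e. it
// depends on Out only or on nothing (kDepOut or kNoDeps).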
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

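// Builds the set of activation op names that support in-place execution by
// checking CanInplaceAct() for each op registered via FOR_EACH_ACTIVATION_OP.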
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
  std::unique_ptr<std::unordered_set<std::string>> ret(
      new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
                                   bwd_functor)                     \
  if (CanInplaceAct<bwd_functor<float>>()) {                        \
    ret->insert(#op_type);                                          \
  }

  FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
  return ret;
}

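// Generates an OpProtoAndCheckerMaker for a simple activation op: one input
// X, one output Out, the common use_mkldnn / use_cudnn / is_test attributes,
// and the given documentation string.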
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }

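// Grad op desc maker that forwards to the backward op only the forward-pass
// tensors (X and/or Out) that the grad functor declares it needs via
// FwdDeps(); X is additionally forwarded whenever MKL-DNN is enabled.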
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                             boost::get<bool>(op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};

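// Selects the kernel for an activation op: plain library and layout by
// default, switching to MKL-DNN when it is compiled in, requested through
// the use_mkldnn attribute, and usable in this context. The kernel data
// type is taken from the input variable `name`.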
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

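// Forward op shared by all activations: Out has the same shape and LoD as X.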
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

Please make sure the input is legal; if the input is a negative value close
to zero, add a small epsilon (1e-12) to it to avoid a negative result caused
by numerical error.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure the input is legal to avoid numeric errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

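// Shape inference for the double-grad (grad_grad) ops: each second-order
// output shares the shape and LoD of whichever forward tensor (X or Out)
// the grad functor depends on.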
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

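// Variant for double-grad ops that only produce DDOut (no DX or DOut),
// e.g. relu_grad_grad and leaky_relu_grad_grad.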
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

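// Allows the activation grad ops to run in place: the gradient w.r.t. X may
// reuse the buffer of the gradient w.r.t. Out.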
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const override {
    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

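// Registers an activation op together with its grad op. The forward op is
// marked as in-place (X -> Out) only when CanInplaceAct() allows it; the
// grad op always allows the dOut -> dX in-place reuse declared above.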
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
Y
                                ops::grad_functor<double>>);
804 805
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
806

807
/* ==========================    relu register  ============================= */
808 809 810 811 812
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
813
                  ops::ActivationGradOpInplaceInference,
814
                  ops::ReluDoubleGradMaker);
815 816
REGISTER_OPERATOR(
    relu_grad_grad,
817
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
818 819 820 821 822 823 824 825 826 827 828

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
829
/* ========================================================================== */
830

831
/* ======================== leaky relu register  ============================ */
832 833 834 835 836 837
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
838
                  ops::ActivationGradOpInplaceInference,
839
                  ops::LeakyReluDoubleGradMaker);
840 841
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
842
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
843

844 845 846 847 848 849 850 851 852 853
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
854 855
/* ========================================================================== */

L
lvmengsi 已提交
856 857 858 859 860 861
/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
862
                  ops::ActivationGradOpInplaceInference,
L
lvmengsi 已提交
863 864 865 866 867 868 869 870 871 872 873 874 875 876
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */