/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

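// An activation can reuse its input buffer for the output (in-place) only
// when its backward pass does not read the forward input X, i.e. the gradient
// functor depends on Out alone (kDepOut) or on nothing at all (kNoDeps).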
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
  std::unique_ptr<std::unordered_set<std::string>> ret(
      new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
                                   bwd_functor)                     \
  if (CanInplaceAct<bwd_functor<float>>()) {                        \
    ret->insert(#op_type);                                          \
  }

  FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
  return ret;
}

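// Generates an OpProtoAndCheckerMaker for a plain activation op: one input X,
// one output Out, the common use_mkldnn / use_cudnn / is_test attributes, and
// OP_COMMENT as the operator documentation.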
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }

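// Grad op desc maker shared by the activation ops. Besides dOut and dX it
// forwards the forward-pass X and/or Out to the grad op only when the
// gradient functor declares a dependency on them (kDepX / kDepOut).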
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};

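// Picks the kernel for an activation op: the data type comes from the variable
// `name`, and the MKL-DNN library/layout is selected when the op sets
// use_mkldnn and MKL-DNN can be used in the current context (the cuDNN path is
// temporarily disabled, see the FIXME below).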
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

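// Forward activation op: Out shares its dimensions and LoD with the input X.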
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

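// Backward activation op: dX shares its dimensions and LoD with dOut.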
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator.

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator.

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

The input must be non-negative. If an input value is negative but close to
zero because of numerical error, add a small epsilon (e.g. 1e-12) to it before
taking the square root.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Make sure the input is strictly positive; zero or negative inputs lead to
numerical errors.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddAttr<bool>("use_mkldnn",
                  "(bool, default false) Only used in mkldnn kernel")
        .SetDefault(false);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

A segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

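// Second-order grad ops. By convention DDX (the grad of the first-order dX)
// is the input, DDOut is the produced grad of dOut, and the optional DX /
// DOut outputs take their shape and LoD from the forward X / Out.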
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", Input("X"));
    // input2: Out
    op->SetInput("Out", Input("Out"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

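// Lets the activation grad ops reuse the dOut buffer for dX.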
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc, bool use_cuda) const override {
    return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

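// Registers the forward and backward operators of an activation: the shared
// ActivationOp / ActivationOpGrad classes, the op-specific maker, var-type
// inference, the grad-op desc maker, and (when the gradient only needs Out)
// the X -> Out in-place inference for the forward op.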
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad,              \
                    ops::ActivationGradOpInplaceInference);

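// Registers float and double CPU kernels for an activation and its gradient.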
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
Y
801

802 803
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
804

805
/* ==========================    relu register  ============================= */
806 807 808 809 810
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
811
                  ops::ActivationGradOpInplaceInference,
812
                  ops::ReluDoubleGradMaker);
813 814
REGISTER_OPERATOR(
    relu_grad_grad,
815
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
816 817 818 819 820 821 822 823 824 825 826

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
827
/* ========================================================================== */
828

829
/* ======================== leaky relu register  ============================ */
830 831 832 833 834 835
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
836
                  ops::ActivationGradOpInplaceInference,
837
                  ops::LeakyReluDoubleGradMaker);
838 839
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
840
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
841

842 843 844 845 846 847 848 849 850 851
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
852 853
/* ========================================================================== */

L
lvmengsi 已提交
854 855 856 857 858 859
/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
860
                  ops::ActivationGradOpInplaceInference,
L
lvmengsi 已提交
861 862 863 864 865 866 867 868 869 870 871 872 873 874
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
882
                  ops::ActivationGradOpInplaceInference,
883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */