/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

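// REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) stamps out the proto maker
// for a simple activation op: one input X, one output Out, the
// use_mkldnn / use_cudnn / is_test attributes, and OP_COMMENT as the op doc.
// For example, REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc) further below
// defines SigmoidOpMaker with SigmoidDoc as its comment.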
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }
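
// REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) generates a
// SingleGradOpDescMaker that builds the KERNEL_TYPE##_grad op from Out and
// the gradient of Out only; these backward ops never need the forward input X.
// For example, REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu) below defines
// ReluGradMaker, which emits a "relu_grad" OpDesc.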

#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      auto* op = new ::paddle::framework::OpDesc();                          \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
    }                                                                        \
  }

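// Kernel selection shared by the forward and backward activation ops: prefer a
// cuDNN kernel when "use_cudnn" is set and cuDNN can be used, then an MKL-DNN
// kernel when "use_mkldnn" is set and available, otherwise the plain kernel.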
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_CUDA
  auto it1 = oper.Attrs().find("use_cudnn");
  if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
    library = framework::LibraryType::kCUDNN;
  }
#endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}

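// Forward op shared by all activations: Out takes the shape and LoD of X.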
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

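// Backward op: the gradient of X is shaped like the forward output Out, and
// the kernel type is also derived from Out.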
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("Out", framework::GradVarName("X"));
    ctx->ShareLoD("Out", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "Out");
  }
};

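// Doc strings consumed by the REGISTER_ACTIVATION_OP_MAKER calls further below;
// UNUSED (from platform/port.h) silences unused-variable warnings for constants
// a particular build configuration does not reference.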
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

UNUSED constexpr char AcosDoc[] = R"DOC(
Arccosine Activation Operator.

$${out}_{i} = \cos^{-1}({input}_{i})$$

)DOC";

UNUSED constexpr char AsinDoc[] = R"DOC(
Arcsine Activation Operator.

$out = \sin^{-1}({input}_{i})$

)DOC";

UNUSED constexpr char AtanDoc[] = R"DOC(
Arctan Activation Operator.

$out = \tan^{-1}({input}_{i})$

)DOC";

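// The makers below carry extra attributes (slopes, thresholds, scales, ...), so
// they spell out Make() by hand instead of using REGISTER_ACTIVATION_OP_MAKER.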
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(Atan, AtanDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Acos, AcosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Asin, AsinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
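
// Activations whose backward pass only reads Out; they are registered below via
// REGISTER_INPLACE_ACTIVATION_OP with SingleOpInplaceInToOut, which (as the name
// suggests) lets the single output reuse the buffer of the single input.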
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(Sigmoid, sigmoid);                 \
  __macro(Relu, relu);                       \
  __macro(Exp, exp);                         \
  __macro(Tanh, tanh);                       \
  __macro(Ceil, ceil);                       \
  __macro(Floor, floor);                     \
  __macro(Sqrt, sqrt);                       \
  __macro(SoftRelu, soft_relu);              \
  __macro(Relu6, relu6);                     \
  __macro(Reciprocal, reciprocal);           \
  __macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(LogSigmoid, logsigmoid);   \
  __macro(SoftShrink, softshrink);   \
  __macro(Abs, abs);                 \
  __macro(Cos, cos);                 \
  __macro(Acos, acos);               \
  __macro(Sin, sin);                 \
  __macro(Asin, asin);               \
  __macro(Atan, atan);               \
  __macro(Round, round);             \
  __macro(Log, log);                 \
  __macro(Square, square);           \
  __macro(Gelu, gelu);               \
  __macro(BRelu, brelu);             \
  __macro(Pow, pow);                 \
  __macro(STanh, stanh);             \
  __macro(Softplus, softplus);       \
  __macro(Softsign, softsign);       \
  __macro(LeakyRelu, leaky_relu);    \
  __macro(TanhShrink, tanh_shrink);  \
  __macro(ELU, elu);                 \
  __macro(HardShrink, hard_shrink);  \
  __macro(Swish, swish);             \
  __macro(ThresholdedRelu, thresholded_relu);

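// REGISTER_INPLACE_ACTIVATION_OP wires in the hand-written OP_NAME##GradMaker
// plus in-place inference for both forward and backward ops, while
// REGISTER_ACTIVATION_OP falls back to DefaultGradOpDescMaker.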
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                   \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,            \
                    ::paddle::operators::OP_NAME##OpMaker,                     \
                    ::paddle::operators::ActivationOpInferVarType,             \
                    ::paddle::operators::OP_NAME##GradMaker,                   \
                    ::paddle::framework::SingleOpInplaceInToOut);              \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
                    ::paddle::framework::SingleOpInplaceInToOut)

#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                    ::paddle::operators::OP_NAME##OpMaker,              \
                    ::paddle::operators::ActivationOpInferVarType,      \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

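// Registers float and double CPU kernels for an activation and its gradient.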
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

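// Expand the lists above into the actual operator and CPU kernel registrations.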
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);