/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"

namespace paddle {
namespace operators {

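// Generates an OpProtoAndCheckerMaker for a simple unary activation: input
// "X", output "Out", an optional "use_mkldnn" switch, and the given doc
// comment. For example, REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc) below
// defines ReluOpMaker.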
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
  class OP_NAME##OpMaker : public framework::OpProtoAndCheckerMaker {   \
   public:                                                              \
    OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker)         \
        : framework::OpProtoAndCheckerMaker(proto, op_checker) {        \
      AddInput("X", "Input of " #OP_NAME " operator");                  \
      AddOutput("Out", "Output of " #OP_NAME " operator");              \
      AddAttr<bool>("use_mkldnn",                                       \
                    "(bool, default false) Only used in mkldnn kernel") \
          .SetDefault(false);                                           \
      AddComment(#OP_COMMENT);                                          \
    }                                                                   \
  }

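// Generates a SingleGradOpDescMaker whose grad op reads only "Out" and
// GRAD(Out) (never "X"), which is what allows the activations registered
// below as "inplace" to reuse the forward output buffer.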
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)        \
  class OP_NAME##GradMaker : public framework::SingleGradOpDescMaker { \
   public:                                                             \
    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;     \
                                                                       \
   protected:                                                          \
    std::unique_ptr<framework::OpDesc> Apply() const override {        \
      auto *op = new framework::OpDesc();                              \
      op->SetType(#KERNEL_TYPE "_grad");                               \
      op->SetInput("Out", Output("Out"));                              \
      op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));  \
                                                                       \
      op->SetAttrMap(Attrs());                                         \
                                                                       \
      op->SetOutput(framework::GradVarName("X"), InputGrad("X"));      \
      return std::unique_ptr<framework::OpDesc>(op);                   \
    }                                                                  \
  }

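// Shape inference shared by every activation: "Out" takes the dimensions and
// LoD of "X", and GRAD(X) takes the dimensions of "Out".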
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
  }
};

constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator.

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator.

$$out = \log \frac{1}{1 + e^{-x}}$$

)DOC";

constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \frac{1}{x}$$

)DOC";

constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
Softshrink Activation Operator.

$$
out = \begin{cases}
    x - \lambda, \text{if } x > \lambda \\
    x + \lambda, \text{if } x < -\lambda \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold", "The value of threshold for HardShrink")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardShrink Activation Operator.

$$
out = \begin{cases}
    x, \text{if } x > threshold \\
    x, \text{if } x < -threshold \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold", "The threshold location of activation")
        .SetDefault(1.0f);
    AddComment(R"DOC(
ThresholdedRelu Activation Operator.

$$
out = \begin{cases}
    x, \text{if } x > threshold \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

A segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

// NOTE(*): Only activations whose gradient can be computed in place register
// a custom gradient maker, which tells the executor exactly which input
// variables the grad op uses. By default, every input variable is assumed to
// be used by the gradient operator.
// The operator names are written in lowercase intentionally.
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

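// Two registration paths: activations whose gradient only needs "Out" are
// registered together with the custom GradMaker generated above; all other
// activations go through plain REGISTER_OP, whose default gradient maker
// assumes every forward variable is needed (see the NOTE above).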
#define REGISTER_INPLACE_ACTIVATION_OP(act_type, op_name)               \
  REGISTER_OPERATOR(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
                    ops::op_name##GradMaker);                           \
  REGISTER_OPERATOR(act_type##_grad, ops::ActivationOpGrad)

#define REGISTER_ACTIVATION_OP(act_type, op_name)                 \
  REGISTER_OP(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
              act_type##_grad, ops::ActivationOpGrad);

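// The first list enumerates the in-place-capable activations registered via
// REGISTER_INPLACE_ACTIVATION_OP; the second covers the remaining ops. As a
// purely illustrative sketch, a hypothetical new in-place activation "foo"
// would get REGISTER_ACTIVATION_OP_MAKER(Foo, FooDoc) and
// REGISTER_ACTIVATION_OP_GRAD_MAKER(Foo, foo) above, plus a
// "__macro(foo, Foo);" entry here and matching functors in activation_op.h.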
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(sigmoid, Sigmoid);                 \
  __macro(relu, Relu);                       \
  __macro(exp, Exp);                         \
  __macro(tanh, Tanh);                       \
  __macro(ceil, Ceil);                       \
  __macro(floor, Floor);                     \
  __macro(sqrt, Sqrt);                       \
  __macro(soft_relu, SoftRelu);              \
  __macro(relu6, Relu6);                     \
  __macro(reciprocal, Reciprocal);           \
  __macro(hard_sigmoid, HardSigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(logsigmoid, LogSigmoid);   \
  __macro(softshrink, SoftShrink);   \
  __macro(abs, Abs);                 \
  __macro(cos, Cos);                 \
  __macro(sin, Sin);                 \
  __macro(round, Round);             \
  __macro(log, Log);                 \
  __macro(square, Square);           \
  __macro(brelu, BRelu);             \
  __macro(pow, Pow);                 \
  __macro(stanh, STanh);             \
  __macro(softplus, Softplus);       \
  __macro(softsign, Softsign);       \
  __macro(leaky_relu, LeakyRelu);    \
  __macro(tanh_shrink, TanhShrink);  \
  __macro(elu, ELU);                 \
  __macro(hard_shrink, HardShrink);  \
  __macro(swish, Swish);             \
  __macro(thresholded_relu, ThresholdedRelu);

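// Registers float and double CPU kernels for an activation and its "_grad"
// op, wiring them to the forward/backward functor pair from activation_op.h.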
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

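// Expand all registrations. For example, the inplace list makes
// REGISTER_INPLACE_ACTIVATION_OP(relu, Relu) emit REGISTER_OPERATOR(relu, ...)
// together with REGISTER_OPERATOR(relu_grad, ...).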
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);