/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include "paddle/fluid/operators/mkldnn_activation_op.h"

namespace paddle {
namespace operators {

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                  \
  class OP_NAME##OpMaker                                                   \
      : public ::paddle::framework::OpProtoAndCheckerMaker {               \
   public:                                                                 \
    OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker)            \
        : ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
      AddInput("X", "Input of " #OP_NAME "operator");                      \
      AddOutput("Out", "Output of" #OP_NAME "operator");                   \
      AddAttr<bool>("use_mkldnn",                                          \
                    "(bool, default false) Only used in mkldnn kernel")    \
          .SetDefault(false);                                              \
      AddComment(OP_COMMENT);                                              \
    }                                                                      \
  }
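// REGISTER_ACTIVATION_OP_MAKER stamps out an OpProtoAndCheckerMaker subclass for a
// "simple" activation: one input X, one output Out, an optional use_mkldnn switch,
// and the documentation string passed as OP_COMMENT. As a rough sketch (not the
// literal preprocessor output), REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc)
// behaves like:
//
//   class SigmoidOpMaker : public ::paddle::framework::OpProtoAndCheckerMaker {
//    public:
//     SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
//         : OpProtoAndCheckerMaker(proto, op_checker) {
//       AddInput("X", "Input of Sigmoid operator");
//       AddOutput("Out", "Output of Sigmoid operator");
//       AddAttr<bool>("use_mkldnn",
//                     "(bool, default false) Only used in mkldnn kernel")
//           .SetDefault(false);
//       AddComment(SigmoidDoc);
//     }
//   };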

#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      auto *op = new ::paddle::framework::OpDesc();                          \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
    }                                                                        \
  }
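// REGISTER_ACTIVATION_OP_GRAD_MAKER builds a SingleGradOpDescMaker whose generated
// backward op ("<kernel>_grad") reads only the forward output Out and GRAD(Out);
// it never touches X, which is what lets the forward kernels listed in
// FOR_EACH_INPLACE_OP_FUNCTOR below overwrite their input. Illustrative sketch of
// the OpDesc produced by the Relu grad maker:
//
//   type:    "relu_grad"
//   inputs:  Out (forward output), Out@GRAD
//   attrs:   copied from the forward op
//   outputs: X@GRAD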

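// Shared forward definition for every activation: the output has the same
// dimensions as X and shares its LoD, so one InferShape covers all of them.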
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

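// Shared backward definition: the gradient w.r.t. X has the same shape as Out.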
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
  }
};

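// Documentation strings for the simple activations. Each *Doc constant below is
// handed to REGISTER_ACTIVATION_OP_MAKER near the end of this file and supplies
// the operator's comment.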
constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator.

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator.

$$out = \log \frac{1}{1 + e^{-x}}$$

)DOC";

constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \frac{1}{x}$$

)DOC";

constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

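// Activations that carry extra attributes (alpha, lambda, thresholds, ...) spell
// out their proto makers by hand instead of using REGISTER_ACTIVATION_OP_MAKER.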
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
Softshrink Activation Operator.

$$
out = \begin{cases}
    x - \lambda, \text{if } x > \lambda \\
    x + \lambda, \text{if } x < -\lambda \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold", "The value of threshold for HardShrink")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardShrink Activation Operator.

$$
out = \begin{cases}
    x, \text{if } x > threshold \\
    x, \text{if } x < -threshold \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold", "The threshold location of activation")
        .SetDefault(1.0f);
    AddComment(R"DOC(
ThresholdedRelu Activation Operator.

$$
out = \begin{cases}
    x, \text{if } x > threshold \\
    0,  \text{otherwise}
    \end{cases}
$$

)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

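// Instantiate a proto maker for each simple activation, pairing it with the
// documentation string defined earlier in this file.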
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

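// Custom grad makers exist only for activations whose gradient is a function of
// Out alone; these are the same kernels registered as in-place-capable through
// FOR_EACH_INPLACE_OP_FUNCTOR below.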
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

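// X-macro lists: FOR_EACH_INPLACE_OP_FUNCTOR covers activations registered with
// their hand-written GradMaker (gradient depends only on Out), while
// FOR_EACH_OP_FUNCTOR covers the rest, which fall back to the default grad op
// desc maker.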
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(Sigmoid, sigmoid);                 \
  __macro(Relu, relu);                       \
  __macro(Exp, exp);                         \
  __macro(Tanh, tanh);                       \
  __macro(Ceil, ceil);                       \
  __macro(Floor, floor);                     \
  __macro(Sqrt, sqrt);                       \
  __macro(SoftRelu, soft_relu);              \
  __macro(Relu6, relu6);                     \
  __macro(Reciprocal, reciprocal);           \
  __macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(LogSigmoid, logsigmoid);   \
  __macro(SoftShrink, softshrink);   \
  __macro(Abs, abs);                 \
  __macro(Cos, cos);                 \
  __macro(Sin, sin);                 \
  __macro(Round, round);             \
  __macro(Log, log);                 \
  __macro(Square, square);           \
  __macro(BRelu, brelu);             \
  __macro(Pow, pow);                 \
  __macro(STanh, stanh);             \
  __macro(Softplus, softplus);       \
  __macro(Softsign, softsign);       \
  __macro(LeakyRelu, leaky_relu);    \
  __macro(TanhShrink, tanh_shrink);  \
  __macro(ELU, elu);                 \
  __macro(HardShrink, hard_shrink);  \
  __macro(Swish, swish);             \
  __macro(ThresholdedRelu, thresholded_relu);

#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                    ::paddle::operators::OP_NAME##OpMaker,          \
                    ::paddle::operators::OP_NAME##GradMaker);       \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
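// REGISTER_INPLACE_ACTIVATION_OP (above) wires an activation to its custom
// OP_NAME##GradMaker. For example, REGISTER_INPLACE_ACTIVATION_OP(Sigmoid, sigmoid)
// roughly yields:
//   REGISTER_OPERATOR(sigmoid, ActivationOp, SigmoidOpMaker, SigmoidGradMaker);
//   REGISTER_OPERATOR(sigmoid_grad, ActivationOpGrad);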

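// REGISTER_ACTIVATION_OP (below) is the non-in-place variant: it relies on
// DefaultGradOpDescMaker<true>, which forwards the forward op's inputs, outputs
// and their gradients to the generated grad op instead of a hand-written maker.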
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                    ::paddle::operators::OP_NAME##OpMaker,              \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

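// Registers float and double CPU kernels for one activation and its gradient.
// Illustrative expansion for sigmoid (functor names as declared in
// activation_op.h):
//
//   REGISTER_OP_CPU_KERNEL(
//       sigmoid,
//       ops::ActivationKernel<CPUDeviceContext, ops::SigmoidFunctor<float>>,
//       ops::ActivationKernel<CPUDeviceContext, ops::SigmoidFunctor<double>>);
//   REGISTER_OP_CPU_KERNEL(
//       sigmoid_grad,
//       ops::ActivationGradKernel<CPUDeviceContext, ops::SigmoidGradFunctor<float>>,
//       ops::ActivationGradKernel<CPUDeviceContext, ops::SigmoidGradFunctor<double>>);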
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

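// Expand the registration lists. FOR_EACH_KERNEL_FUNCTOR is expected to come from
// activation_op.h and supplies (kernel_type, Functor, GradFunctor) triples for
// every activation kernel.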
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);