/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
  class OP_NAME##OpMaker                                                \
      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
   public:                                                              \
    void Make() override {                                              \
      AddInput("X", "Input of " #OP_NAME " operator");                  \
      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X");   \
      AddAttr<bool>("use_mkldnn",                                       \
                    "(bool, default false) Only used in mkldnn kernel") \
          .SetDefault(false);                                           \
      AddComment(OP_COMMENT);                                           \
    }                                                                   \
  }
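
// For reference: an invocation such as REGISTER_ACTIVATION_OP_MAKER(Relu,
// ReluDoc); (see the list near the end of this file) defines a ReluOpMaker
// whose Make() declares the input "X", the output "Out" (marked reusable
// with X's buffer), the boolean "use_mkldnn" attribute, and the ReluDoc
// string as the operator comment.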

#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      auto* op = new ::paddle::framework::OpDesc();                          \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
    }                                                                        \
  }
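
// For reference: for a kernel type such as relu, the generated GradMaker
// emits a "relu_grad" OpDesc whose inputs are the forward output "Out" and
// its gradient (GradVarName("Out")), and whose only output is the gradient
// of "X". The forward input "X" itself is not forwarded, so this maker is
// only used for activations whose gradient can be computed from "Out" alone
// (see the REGISTER_ACTIVATION_OP_GRAD_MAKER list below).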

framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
      ctx.GetPlace(), layout, library);
}
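
// Dispatch summary: both ActivationOp and ActivationOpGrad below use this
// helper. By default it selects the plain library with kAnyLayout; when the
// build has PADDLE_WITH_MKLDNN, the operator declares a "use_mkldnn"
// attribute, and platform::CanMKLDNNBeUsed(ctx) holds, it switches to the
// MKLDNN library and layout. The kernel data type is always taken from the
// tensor named by `name` ("X" for the forward op, "Out" for the grad op).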

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "Out");
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(Sigmoid, sigmoid);                 \
  __macro(Relu, relu);                       \
  __macro(Exp, exp);                         \
  __macro(Tanh, tanh);                       \
  __macro(Ceil, ceil);                       \
  __macro(Floor, floor);                     \
  __macro(Sqrt, sqrt);                       \
  __macro(SoftRelu, soft_relu);              \
  __macro(Relu6, relu6);                     \
  __macro(Reciprocal, reciprocal);           \
  __macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(LogSigmoid, logsigmoid);   \
  __macro(SoftShrink, softshrink);   \
  __macro(Abs, abs);                 \
  __macro(Cos, cos);                 \
  __macro(Sin, sin);                 \
  __macro(Round, round);             \
  __macro(Log, log);                 \
  __macro(Square, square);           \
  __macro(BRelu, brelu);             \
  __macro(Pow, pow);                 \
  __macro(STanh, stanh);             \
  __macro(Softplus, softplus);       \
  __macro(Softsign, softsign);       \
  __macro(LeakyRelu, leaky_relu);    \
  __macro(TanhShrink, tanh_shrink);  \
  __macro(ELU, elu);                 \
  __macro(HardShrink, hard_shrink);  \
  __macro(Swish, swish);             \
  __macro(ThresholdedRelu, thresholded_relu);

#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                    ::paddle::operators::OP_NAME##OpMaker,          \
                    ::paddle::operators::OP_NAME##GradMaker);       \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                    ::paddle::operators::OP_NAME##OpMaker,              \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
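
// Usage sketch: REGISTER_ACTIVATION_OP(Abs, abs), for example, registers the
// forward operator "abs" with ActivationOp, AbsOpMaker and the default grad
// op desc maker, plus "abs_grad" with ActivationOpGrad. The INPLACE variant
// above differs only in using the OP_NAME##GradMaker generated earlier,
// whose backward pass reads "Out" instead of "X", so the forward input need
// not be kept alive for the gradient computation.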

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
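
// Note: FOR_EACH_KERNEL_FUNCTOR and the per-activation functors (e.g.
// ops::ReluFunctor, ops::ReluGradFunctor) are assumed to be provided by
// activation_op.h. As a rough sketch, one expansion such as
// REGISTER_ACTIVATION_CPU_KERNEL(relu, ReluFunctor, ReluGradFunctor)
// registers float and double CPU kernels for "relu" and "relu_grad".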