/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include "paddle/fluid/operators/mkldnn_activation_op.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

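// REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) generates the OpMaker for
// a plain one-input activation: input "X", output "Out" (declared with
// Reuse("X") so it may share X's buffer), an optional "use_mkldnn" attribute,
// and OP_COMMENT as the operator documentation.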
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
  class OP_NAME##OpMaker                                                \
      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
   public:                                                              \
    void Make() override {                                              \
      AddInput("X", "Input of " #OP_NAME " operator");                  \
      AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X");   \
      AddAttr<bool>("use_mkldnn",                                       \
                    "(bool, default false) Only used in mkldnn kernel") \
          .SetDefault(false);                                           \
      AddComment(OP_COMMENT);                                           \
    }                                                                   \
  }

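// REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) generates a grad-op
// desc maker whose backward op "<KERNEL_TYPE>_grad" takes only Out and
// Out@GRAD (not X) and produces X@GRAD; activations whose gradient can be
// written this way are the ones registered as in-place capable below.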
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      auto* op = new ::paddle::framework::OpDesc();                          \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
    }                                                                        \
  }

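// Kernel selection shared by the forward and backward activation ops: the data
// type comes from the variable `name`, and the MKLDNN kernel/layout is chosen
// only when the op defines a "use_mkldnn" attribute and
// platform::CanMKLDNNBeUsed(ctx) reports that MKLDNN is usable here.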
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
      ctx.GetPlace(), layout, library);
}

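// Forward operator shared by every activation in this file: Out gets the same
// dimensions and LoD as X.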
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

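// Out is declared with the same variable type and data type as X.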
class ActivationOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    auto x_name = op_desc.Input("X")[0];
    auto out_name = op_desc.Output("Out")[0];
    auto& x = block->FindRecursiveOrCreateVar(x_name);
    auto& out = block->FindRecursiveOrCreateVar(out_name);
    out.SetType(x.GetType());
    out.SetDataType(x.GetDataType());
  }
};

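// Backward operator shared by every activation: X@GRAD takes its dimensions
// and LoD from Out, and the kernel type is resolved from "Out".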
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("Out", framework::GradVarName("X"));
    ctx->ShareLoD("Out", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "Out");
  }
};

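// Documentation strings for the activations generated via
// REGISTER_ACTIVATION_OP_MAKER below; x denotes the input and out the output.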
__attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

__attribute__((unused)) constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

__attribute__((unused)) constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

__attribute__((unused)) constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

__attribute__((unused)) constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

__attribute__((unused)) constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

__attribute__((unused)) constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

__attribute__((unused)) constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

__attribute__((unused)) constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

__attribute__((unused)) constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

__attribute__((unused)) constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

__attribute__((unused)) constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

__attribute__((unused)) constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

__attribute__((unused)) constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

__attribute__((unused)) constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

__attribute__((unused)) constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

__attribute__((unused)) constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

__attribute__((unused)) constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

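// Activations that carry extra attributes define their OpMakers by hand
// instead of using REGISTER_ACTIVATION_OP_MAKER.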
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > threshold \\
            x, \text{if } x < -threshold \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

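// Instantiate an OpMaker for each of the simple activations documented above.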
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

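// Grad-op desc makers for the activations whose backward pass needs only Out;
// this list mirrors FOR_EACH_INPLACE_OP_FUNCTOR below.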
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

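// Ops in FOR_EACH_INPLACE_OP_FUNCTOR are registered with their custom
// OP_NAME##GradMaker (backward uses Out only), while ops in
// FOR_EACH_OP_FUNCTOR fall back to the framework's DefaultGradOpDescMaker.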
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(Sigmoid, sigmoid);                 \
  __macro(Relu, relu);                       \
  __macro(Exp, exp);                         \
  __macro(Tanh, tanh);                       \
  __macro(Ceil, ceil);                       \
  __macro(Floor, floor);                     \
  __macro(Sqrt, sqrt);                       \
  __macro(SoftRelu, soft_relu);              \
  __macro(Relu6, relu6);                     \
  __macro(Reciprocal, reciprocal);           \
  __macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(LogSigmoid, logsigmoid);   \
  __macro(SoftShrink, softshrink);   \
  __macro(Abs, abs);                 \
  __macro(Cos, cos);                 \
  __macro(Sin, sin);                 \
  __macro(Round, round);             \
  __macro(Log, log);                 \
  __macro(Square, square);           \
  __macro(BRelu, brelu);             \
  __macro(Pow, pow);                 \
  __macro(STanh, stanh);             \
  __macro(Softplus, softplus);       \
  __macro(Softsign, softsign);       \
  __macro(LeakyRelu, leaky_relu);    \
  __macro(TanhShrink, tanh_shrink);  \
  __macro(ELU, elu);                 \
  __macro(HardShrink, hard_shrink);  \
  __macro(Swish, swish);             \
  __macro(ThresholdedRelu, thresholded_relu);

#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                    ::paddle::operators::OP_NAME##OpMaker,          \
                    ::paddle::operators::ActivationOpInferVarType,  \
                    ::paddle::operators::OP_NAME##GradMaker);       \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                    ::paddle::operators::OP_NAME##OpMaker,              \
                    ::paddle::operators::ActivationOpInferVarType,      \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

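// Instantiates the float and double CPU kernels for one activation type and
// its gradient.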
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
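// A minimal sketch (hypothetical names) of how one more simple activation
// would be wired through this file, assuming a FooFunctor/FooGradFunctor pair
// exists alongside the other functors in activation_op.h:
//
//   __attribute__((unused)) constexpr char FooDoc[] = R"DOC(
//   Foo Activation Operator.
//
//   $out = foo(x)$
//
//   )DOC";
//   REGISTER_ACTIVATION_OP_MAKER(Foo, FooDoc);
//
// The op would then be listed as `__macro(Foo, foo);` in FOR_EACH_OP_FUNCTOR
// above, with a matching `__macro(foo, FooFunctor, FooGradFunctor);` entry in
// the FOR_EACH_KERNEL_FUNCTOR list so its CPU kernels get registered.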