/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
D
dzhwinter 已提交
16
#include <string>
K
Krzysztof Binias 已提交
17
#include "paddle/fluid/operators/mkldnn_activation_op.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/port.h"
Q
qijun 已提交
19 20 21 22

namespace paddle {
namespace operators {

23 24 25 26 27 28 29 30
using paddle::framework::Tensor;

// Defines `OP_NAME##OpMaker`, the proto/checker maker shared by activation
// operators that have one input "X", one output "Out", and only the
// "use_mkldnn" attribute.
//
//   OP_NAME    - CamelCase operator name, also used in the input/output
//                descriptions.
//   OP_COMMENT - an expression convertible to the comment string, e.g. one of
//                the `*Doc` raw-string constants in this file.
//
// OP_COMMENT is passed through as an expression, not stringized with
// `#OP_COMMENT`: stringizing would register the identifier itself (the
// literal text "SigmoidDoc", say) as the documentation instead of the text of
// the constant it names.
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
  class OP_NAME##OpMaker                                                \
      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
   public:                                                              \
    void Make() override {                                              \
      AddInput("X", "Input of " #OP_NAME " operator");                  \
      AddOutput("Out", "Output of " #OP_NAME " operator");              \
      AddAttr<bool>("use_mkldnn",                                       \
                    "(bool, default false) Only used in mkldnn kernel") \
          .SetDefault(false);                                           \
      AddComment(OP_COMMENT);                                           \
    }                                                                   \
  }

// Defines `OP_NAME##GradMaker`, a SingleGradOpDescMaker that describes the
// "<KERNEL_TYPE>_grad" op. The grad op receives only the forward output "Out"
// and its gradient (it is never given "X"). It copies all forward attributes
// and writes the gradient of "X".
//
// The desc is owned by a unique_ptr from the moment it is allocated. If any
// of the setters below throws, it is released instead of leaked.
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      std::unique_ptr<::paddle::framework::OpDesc> op(                       \
          new ::paddle::framework::OpDesc());                                \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return op;                                                             \
    }                                                                        \
  }

// Chooses the kernel for an activation (or activation-gradient) operator.
//
// The data type comes from the tensor bound to input slot `name` ("X" for
// the forward op, "Out" for its gradient). By default a plain-library kernel
// with kAnyLayout is requested. In builds with PADDLE_WITH_MKLDNN, the MKLDNN
// library and layout are requested instead when the operator declares a
// "use_mkldnn" attribute and platform::CanMKLDNNBeUsed(ctx) returns true.
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto kernel_library = framework::LibraryType::kPlain;
  auto kernel_layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  const auto& op_attrs = oper.Attrs();
  const bool declares_mkldnn_attr =
      op_attrs.find("use_mkldnn") != op_attrs.end();
  // The library is necessarily still kPlain here, so only the attribute and
  // the runtime capability check decide whether MKLDNN is selected.
  if (declares_mkldnn_attr && platform::CanMKLDNNBeUsed(ctx)) {
    kernel_library = framework::LibraryType::kMKLDNN;
    kernel_layout = framework::DataLayout::kMKLDNN;
  }
#endif
  const auto* input_tensor = ctx.Input<framework::Tensor>(name);
  return framework::OpKernelType(framework::ToDataType(input_tensor->type()),
                                 ctx.GetPlace(), kernel_layout, kernel_library);
}

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

82
  void InferShape(framework::InferShapeContext* ctx) const override {
83
    ctx->ShareDim("X", /*->*/ "Out");
F
fengjiayi 已提交
84
    ctx->ShareLoD("X", /*->*/ "Out");
Q
qijun 已提交
85
  }
86

87
 protected:
88 89 90 91
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
Q
qijun 已提交
92 93
};

94 95 96 97 98 99 100 101 102 103 104 105 106
// Var-type inference for activations: the "Out" variable takes both the
// variable type and the data type of the "X" variable. Both variables are
// looked up with BlockDesc::FindRecursiveOrCreateVar.
class ActivationOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    const auto in_name = op_desc.Input("X")[0];
    const auto out_name = op_desc.Output("Out")[0];
    auto& in_var = block->FindRecursiveOrCreateVar(in_name);
    auto& out_var = block->FindRecursiveOrCreateVar(out_name);
    out_var.SetType(in_var.GetType());
    out_var.SetDataType(in_var.GetDataType());
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

111
  void InferShape(framework::InferShapeContext* ctx) const override {
112 113
    ctx->ShareDim("Out", framework::GradVarName("X"));
    ctx->ShareLoD("Out", framework::GradVarName("X"));
Q
qijun 已提交
114
  }
115

116
 protected:
117 118 119 120
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "Out");
  }
Q
qijun 已提交
121 122
};

D
dzhwinter 已提交
123
// Documentation strings for the activations whose makers are generated by
// REGISTER_ACTIVATION_OP_MAKER. Each constant is passed there as OP_COMMENT.
// These are raw string literals, so their text is kept verbatim, including
// the leading and trailing newlines.
//
// NOTE(review): some formulas spell LaTeX commands with a doubled backslash
// (e.g. `\\frac`) and others with a single one (e.g. `\frac`). Inside a raw
// string these are different characters. Which form the downstream doc
// renderer expects is not visible from this file, so both are left unchanged.
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
252
 public:
Y
Yu Yang 已提交
253
  void Make() override {
D
dzhwinter 已提交
254 255 256
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
K
Kexin Zhao 已提交
257
    AddComment(R"DOC(
D
dzhwinter 已提交
258
LeakyRelu Activation Operator.
K
Kexin Zhao 已提交
259

D
dzhwinter 已提交
260
$out = \max(x, \alpha * x)$
K
Kexin Zhao 已提交
261 262

)DOC");
263 264 265
  }
};

D
dzhwinter 已提交
266
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
K
kexinzhao 已提交
267
 public:
Y
Yu Yang 已提交
268
  void Make() override {
D
dzhwinter 已提交
269 270 271
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
K
Kexin Zhao 已提交
272
    AddComment(R"DOC(
273 274 275 276 277 278 279 280
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases} 
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}
K
Kexin Zhao 已提交
281 282

)DOC");
K
kexinzhao 已提交
283 284 285
  }
};

D
dzhwinter 已提交
286
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
287
 public:
Y
Yu Yang 已提交
288
  void Make() override {
D
dzhwinter 已提交
289 290
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
Y
yuyang18 已提交
291 292
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
D
dzhwinter 已提交
293
        .SetDefault(0.5f);
K
Kexin Zhao 已提交
294
    AddComment(R"DOC(
Y
yuyang18 已提交
295
:strong:`HardShrink activation operator`
K
Kexin Zhao 已提交
296

Y
yuyang18 已提交
297 298 299 300 301 302
..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}
K
Kexin Zhao 已提交
303 304

)DOC");
305 306 307
  }
};

308 309
// Proto maker for BRelu (bounded ReLU): input "X", output "Out", and float
// attributes "t_min" (default 0) and "t_max" (default 24). The documented
// formula clamps x from below by t_min, then from above by t_max.
//
// The nesting matters: writing \max(\min(x, t_min), t_max) would describe a
// constant t_max whenever t_min <= t_max.
// NOTE(review): keep this in sync with the BRelu functor in activation_op.h,
// which is not visible here.
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

// Proto maker for SoftRelu: input "X", output "Out", and a float attribute
// "threshold" (default 40.0f). The documented formula clamps x into
// [-threshold, threshold] before applying ln(1 + e^x).
//
// NOTE(review): keep this in sync with the SoftRelu functor in
// activation_op.h, which is not visible here.
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

// Proto maker for ELU: input "X", output "Out", and a float attribute
// "alpha" (default 1.0f), which scales the negative branch of the formula.
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU")
        .SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

// Proto maker for Relu6: input "X", output "Out", and a float attribute
// "threshold" (default 6.0f) that caps the output.
//
// The formula names the cap `threshold` rather than the literal 6, because
// the cap is configurable and 6 is only its default.
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), threshold)$

)DOC");
  }
};

// Proto maker for Pow: input "X", output "Out", and a float attribute
// "factor" (default 1.0f) used as the exponent.
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow")
        .SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

// Proto maker for STanh (scaled tanh): input "X", output "Out", and float
// attributes "scale_a" (default 2/3) and "scale_b" (default 1.7159). These
// appear in the formula as `a` and `b`.
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

// Proto maker for ThresholdedRelu: input "X", output "Out", and a float
// attribute "threshold" (default 1.0f). Per the documented formula, x passes
// through only when it is strictly greater than the threshold.
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

// Proto maker for HardSigmoid: input "X", output "Out", and float attributes
// "slope" (default 0.2f) and "offset" (default 0.5f) of the clipped linear
// approximation of sigmoid.
//
// The documentation refers to the intercept by its attribute name, `offset`.
// The op has no attribute called `shift`.
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391), 
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

// Proto maker for Swish: input "X", output "Out", and a float attribute
// "beta" (default 1.0f) that appears in the exponent of the formula.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator")
        .SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

// ---- Generated OpMaker classes ---------------------------------------------
// Makers for activations with extra attributes are written out by hand above.
// Only the plain X -> Out activations are generated here.

// Activations that also get a custom `<Name>GradMaker` below.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);

// Activations without a custom grad maker.
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

// ---- Grad-desc makers -------------------------------------------------------
// The second argument is the registered op type. The maker emits a desc of
// type "<op type>_grad".
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// X-macro over the activations that have a matching `<OpName>GradMaker` and
// are registered through REGISTER_INPLACE_ACTIVATION_OP. It expands to
// `OP_MACRO(OpName, op_type);` once per entry.
//
// The parameter is named OP_MACRO rather than `__macro`, because identifiers
// containing a double underscore are reserved for the implementation.
#define FOR_EACH_INPLACE_OP_FUNCTOR(OP_MACRO) \
  OP_MACRO(Sigmoid, sigmoid);                 \
  OP_MACRO(Relu, relu);                       \
  OP_MACRO(Exp, exp);                         \
  OP_MACRO(Tanh, tanh);                       \
  OP_MACRO(Ceil, ceil);                       \
  OP_MACRO(Floor, floor);                     \
  OP_MACRO(Sqrt, sqrt);                       \
  OP_MACRO(SoftRelu, soft_relu);              \
  OP_MACRO(Relu6, relu6);                     \
  OP_MACRO(Reciprocal, reciprocal);           \
  OP_MACRO(HardSigmoid, hard_sigmoid);

// X-macro over the activations registered through REGISTER_ACTIVATION_OP,
// i.e. with the framework's default grad-desc maker. It expands to
// `OP_MACRO(OpName, op_type);` once per entry.
//
// The parameter is named OP_MACRO rather than `__macro`, because identifiers
// containing a double underscore are reserved for the implementation.
#define FOR_EACH_OP_FUNCTOR(OP_MACRO)         \
  OP_MACRO(LogSigmoid, logsigmoid);           \
  OP_MACRO(SoftShrink, softshrink);           \
  OP_MACRO(Abs, abs);                         \
  OP_MACRO(Cos, cos);                         \
  OP_MACRO(Sin, sin);                         \
  OP_MACRO(Round, round);                     \
  OP_MACRO(Log, log);                         \
  OP_MACRO(Square, square);                   \
  OP_MACRO(BRelu, brelu);                     \
  OP_MACRO(Pow, pow);                         \
  OP_MACRO(STanh, stanh);                     \
  OP_MACRO(Softplus, softplus);               \
  OP_MACRO(Softsign, softsign);               \
  OP_MACRO(LeakyRelu, leaky_relu);            \
  OP_MACRO(TanhShrink, tanh_shrink);          \
  OP_MACRO(ELU, elu);                         \
  OP_MACRO(HardShrink, hard_shrink);          \
  OP_MACRO(Swish, swish);                     \
  OP_MACRO(ThresholdedRelu, thresholded_relu);

// Shared body of the two registration macros below. It registers the forward
// op KERNEL_TYPE (ActivationOp + OP_NAME##OpMaker + ActivationOpInferVarType,
// with GRAD_MAKER describing its gradient). It also registers the matching
// KERNEL_TYPE##_grad op as an ActivationOpGrad.
#define REGISTER_ACTIVATION_OP_WITH_GRAD_MAKER(OP_NAME, KERNEL_TYPE,  \
                                               GRAD_MAKER)            \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,   \
                    ::paddle::operators::OP_NAME##OpMaker,            \
                    ::paddle::operators::ActivationOpInferVarType,    \
                    GRAD_MAKER);                                      \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

// Registration for activations whose gradient desc comes from the custom
// OP_NAME##GradMaker.
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
  REGISTER_ACTIVATION_OP_WITH_GRAD_MAKER(                    \
      OP_NAME, KERNEL_TYPE, ::paddle::operators::OP_NAME##GradMaker)

// Registration for activations whose gradient desc comes from the framework's
// DefaultGradOpDescMaker<true>.
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
  REGISTER_ACTIVATION_OP_WITH_GRAD_MAKER(            \
      OP_NAME, KERNEL_TYPE, ::paddle::framework::DefaultGradOpDescMaker<true>)

// Registers float and double CPU kernels for an activation.
//   act_type          - op type, e.g. relu; its gradient op is act_type##_grad.
//   functor           - forward functor template in `ops`, instantiated as
//                       ActivationKernel<CPUDeviceContext, functor<T>>.
//   grad_functor      - backward functor template in `ops`, instantiated as
//                       ActivationGradKernel<CPUDeviceContext, grad_functor<T>>.
// The template arguments contain unparenthesized commas, so
// REGISTER_OP_CPU_KERNEL receives each kernel type through its variadic
// arguments.
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_CPU_KERNEL(                                               \
      act_type,                                                         \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,         \
                            ops::functor<float>>,                       \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,         \
                            ops::functor<double>>);                     \
  REGISTER_OP_CPU_KERNEL(                                               \
      act_type##_grad,                                                  \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,     \
                                ops::grad_functor<float>>,              \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,     \
                                ops::grad_functor<double>>);

// Instantiate all registrations, in this order:
//  1. ops that use the framework's default grad-desc maker,
//  2. ops that use their custom `<Name>GradMaker`,
//  3. float/double CPU kernels for every entry of FOR_EACH_KERNEL_FUNCTOR
//     (defined outside this file, presumably in activation_op.h).
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);