/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

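// Generates an OpProtoAndCheckerMaker for a simple unary activation: one
// input X, one output Out, an optional use_mkldnn switch, and the given
// documentation string as the operator comment.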
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)               \
  class OP_NAME##OpMaker                                                \
      : public ::paddle::framework::OpProtoAndCheckerMaker {            \
   public:                                                              \
    void Make() override {                                              \
      AddInput("X", "Input of " #OP_NAME " operator");                  \
      AddOutput("Out", "Output of " #OP_NAME " operator");              \
      AddAttr<bool>("use_mkldnn",                                       \
                    "(bool, default false) Only used in mkldnn kernel") \
          .SetDefault(false);                                           \
      AddComment(#OP_COMMENT);                                          \
    }                                                                   \
  }

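// Generates a SingleGradOpDescMaker that builds the KERNEL_TYPE_grad op from
// the forward output Out and its gradient only (the input X is not needed),
// which is what makes the listed activations safe to compute in place.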
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
  class OP_NAME##GradMaker                                                   \
      : public ::paddle::framework::SingleGradOpDescMaker {                  \
   public:                                                                   \
    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
                                                                             \
   protected:                                                                \
    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
      auto* op = new ::paddle::framework::OpDesc();                          \
      op->SetType(#KERNEL_TYPE "_grad");                                     \
      op->SetInput("Out", Output("Out"));                                    \
      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
                   OutputGrad("Out"));                                       \
                                                                             \
      op->SetAttrMap(Attrs());                                               \
                                                                             \
      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
    }                                                                        \
  }

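// Common kernel-type selection for the forward and backward activation ops:
// fall back to a plain kernel unless the op carries use_mkldnn and MKLDNN is
// actually usable in the current context.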
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
      ctx.GetPlace(), layout, library);
}

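// Forward activation op: Out keeps the shape and LoD of X.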
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

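// Out inherits both the data type and the variable type of X.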
class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

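// Backward activation op: the gradient of X takes its shape and LoD from Out.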
class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("Out", framework::GradVarName("X"));
    ctx->ShareLoD("Out", framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "Out");
  }
};

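// Documentation strings for the activations that need no extra attributes;
// each one is consumed by a matching REGISTER_ACTIVATION_OP_MAKER call
// further down.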
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = ceil(x)$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = floor(x)$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \frac{x}{1 + |x|}$$

)DOC";

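// Activations with operator-specific attributes get hand-written makers.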
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

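// Instantiate the simple op makers from the documentation strings above.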
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

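// Dedicated grad makers for the activations whose gradient only needs Out;
// everything else relies on the default grad op desc maker at registration.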
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

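// Expanded by the REGISTER_* macros below. The in-place list mirrors the grad
// makers defined above; FOR_EACH_OP_FUNCTOR covers the remaining activations.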
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
  __macro(Sigmoid, sigmoid);                 \
  __macro(Relu, relu);                       \
  __macro(Exp, exp);                         \
  __macro(Tanh, tanh);                       \
  __macro(Ceil, ceil);                       \
  __macro(Floor, floor);                     \
  __macro(Sqrt, sqrt);                       \
  __macro(SoftRelu, soft_relu);              \
  __macro(Relu6, relu6);                     \
  __macro(Reciprocal, reciprocal);           \
  __macro(HardSigmoid, hard_sigmoid);

#define FOR_EACH_OP_FUNCTOR(__macro) \
  __macro(LogSigmoid, logsigmoid);   \
  __macro(SoftShrink, softshrink);   \
  __macro(Abs, abs);                 \
  __macro(Cos, cos);                 \
  __macro(Sin, sin);                 \
  __macro(Round, round);             \
  __macro(Log, log);                 \
  __macro(Square, square);           \
  __macro(BRelu, brelu);             \
  __macro(Pow, pow);                 \
  __macro(STanh, stanh);             \
  __macro(Softplus, softplus);       \
  __macro(Softsign, softsign);       \
  __macro(LeakyRelu, leaky_relu);    \
  __macro(TanhShrink, tanh_shrink);  \
  __macro(ELU, elu);                 \
  __macro(HardShrink, hard_shrink);  \
  __macro(Swish, swish);             \
  __macro(ThresholdedRelu, thresholded_relu);

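// Two registration flavours: the in-place variant wires in the dedicated
// OP_NAME##GradMaker, while the plain variant falls back to the framework's
// default grad op desc maker.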
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
                    ::paddle::operators::OP_NAME##OpMaker,          \
                    ::paddle::operators::ActivationOpInferVarType,  \
                    ::paddle::operators::OP_NAME##GradMaker);       \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
                    ::paddle::operators::OP_NAME##OpMaker,              \
                    ::paddle::operators::ActivationOpInferVarType,      \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)

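// Registers float and double CPU kernels for both the forward op and its
// gradient.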
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

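// Expanding the three lists performs all remaining registration: every
// activation op, its _grad op, and the float/double CPU kernels
// (FOR_EACH_KERNEL_FUNCTOR enumerates the functor pairs; see activation_op.h).
// As a rough, hypothetical sketch, a new activation would follow the same
// pattern: implement its functors, give it an op maker above, and append an
//   __macro(NewAct, new_act);
// entry to the appropriate FOR_EACH list.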
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);