/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
  std::unique_ptr<std::unordered_set<std::string>> ret(
      new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
                                   bwd_functor)                     \
  if (CanInplaceAct<bwd_functor<float>>()) {                        \
    ret->insert(#op_type);                                          \
  }

  FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
  return ret;
}
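
// Sketch of one expansion of the macro above (assuming FOR_EACH_ACTIVATION_OP
// contains an entry like (tanh, Tanh, TanhFunctor, TanhGradFunctor)):
//   if (CanInplaceAct<TanhGradFunctor<float>>()) { ret->insert("tanh"); }
// i.e. an op joins the in-place set only when its backward pass needs the
// forward output (kDepOut) or nothing at all (kNoDeps), never the input X.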

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)                    \
  class OP_NAME##OpMaker                                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker {                 \
   public:                                                                   \
    void Make() override {                                                   \
      AddInput("X", "Input of " #OP_NAME " operator");                       \
      AddOutput("Out", "Output of " #OP_NAME " operator");                   \
      AddAttr<bool>("use_mkldnn",                                            \
                    "(bool, default false) Only used in mkldnn kernel")      \
          .SetDefault(false);                                                \
      AddAttr<bool>("use_cudnn",                                             \
                    "(bool, default false) Only used in cudnn kernel, need " \
                    "install cudnn")                                         \
          .SetDefault(false);                                                \
      AddAttr<bool>(                                                         \
          "is_test",                                                         \
          "(bool, default false) Set to true for inference only, false "     \
          "for training. Some layers may run faster when this is true.")     \
          .SetDefault(false);                                                \
      AddComment(OP_COMMENT);                                                \
    }                                                                        \
  }
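
// Example usage (see the registrations further below):
// REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc) defines a SigmoidOpMaker
// whose Make() declares the single input "X", the single output "Out", the
// use_mkldnn / use_cudnn / is_test attributes, and attaches SigmoidDoc as the
// operator comment.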

template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType(ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      op->SetInput("X", Input("X"));
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", Output("Out"));
    }

    return op;
  }
};
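
// Note: kDepValue is an ActBwdOpFwdDeps bitmask, so each activation's grad op
// only receives the forward tensors it actually needs. Relu, for example,
// writes its gradient purely in terms of Out, so its grad op gets "Out" but
// not "X".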

framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  framework::LibraryType library{framework::LibraryType::kPlain};
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
//   auto it1 = oper.Attrs().find("use_cudnn");
//   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
//     library = framework::LibraryType::kCUDNN;
//   }
// #endif
#ifdef PADDLE_WITH_MKLDNN
  auto it = oper.Attrs().find("use_mkldnn");
  if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
      platform::CanMKLDNNBeUsed(ctx)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
      library);
}
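
// Kernel-selection sketch: with the cuDNN branch above temporarily disabled,
// the library stays kPlain unless the op sets use_mkldnn and MKL-DNN can be
// used for this context, in which case the MKLDNN library and layout are
// picked. The data type always follows the variable named by `name`
// (e.g. "X" for the forward op and Out@GRAD for the backward op).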

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator

$$out = \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator

$$out = \\log \\frac{1}{1 + e^{-x}}$$

)DOC";

UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.

$out = e^x$

)DOC";

UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.

$out = \max(x, 0)$

)DOC";

UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.

$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$

)DOC";

UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.

$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.

$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

)DOC";

UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.

Please make sure the input is valid. If the input is a negative value close to
zero, add a small epsilon (1e-12) to it to avoid a negative number caused by
numerical error.

$out = \sqrt{x}$

)DOC";

UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.

Please make sure the input is valid (positive); otherwise numerical errors may occur.

$out = \frac{1}{\sqrt{x}}$

)DOC";

UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.

$out = |x|$

)DOC";

UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.

$out = \left \lceil x \right \rceil$

)DOC";

UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.

$out = \left \lfloor x \right \rfloor$

)DOC";

UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.

$out = cos(x)$

)DOC";

UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.

$out = sin(x)$

)DOC";

UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.

$out = [x]$

)DOC";

UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.

$$out = \\frac{1}{x}$$

)DOC";

UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.

$out = \ln(x)$

Natural logarithm of x.

)DOC";

UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.

$out = x^2$

)DOC";

UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.

$out = \ln(1 + e^{x})$

)DOC";

UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.

$$out = \\frac{x}{1 + \|x\|}$$

)DOC";

class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of acos operator");
    AddOutput("Out", "Output of acos operator");
    AddComment(R"DOC(
Arccosine Activation Operator.

$$out = \cos^{-1}(x)$$

)DOC");
  }
};

class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of asin operator");
    AddOutput("Out", "Output of asin operator");
    AddComment(R"DOC(
Arcsine Activation Operator.

$$out = \sin^{-1}(x)$$

)DOC");
  }
};

class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of atan operator");
    AddOutput("Out", "Output of atan operator");
    AddComment(R"DOC(
Arctangent Activation Operator.

$$out = \tan^{-1}(x)$$

)DOC");
  }
};

class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of LeakyRelu operator");
    AddOutput("Out", "Output of LeakyRelu operator");
    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
    AddComment(R"DOC(
LeakyRelu Activation Operator.

$out = \max(x, \alpha * x)$

)DOC");
  }
};

class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Softshrink operator");
    AddOutput("Out", "Output of Softshrink operator");
    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`Softshrink Activation Operator`

..  math::
    out = \begin{cases}
         x - \lambda, \text{if } x > \lambda \\
         x + \lambda, \text{if } x < -\lambda \\
         0,  \text{otherwise}
         \end{cases}

)DOC");
  }
};

class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardShrink operator");
    AddOutput("Out", "Output of HardShrink operator");
    AddAttr<float>("threshold",
                   "The value of threshold for HardShrink. [default: 0.5]")
        .SetDefault(0.5f);
    AddComment(R"DOC(
:strong:`HardShrink activation operator`

..  math::
    out = \begin{cases}
            x, \text{if } x > \lambda \\
            x, \text{if } x < -\lambda \\
            0,  \text{otherwise}
          \end{cases}

)DOC");
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of BRelu operator");
    AddOutput("Out", "Output of BRelu operator");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$out = \min(\max(x, t_{min}), t_{max})$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$

)DOC");
  }
};

class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ELU operator");
    AddOutput("Out", "Output of ELU operator");
    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
    AddComment(R"DOC(
ELU Activation Operator.

Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.

$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Relu6 operator");
    AddOutput("Out", "Output of Relu6 operator");
    AddAttr<float>("threshold", "The threshold value of Relu6")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$out = \min(\max(0, x), 6)$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$out = x^{factor}$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of STanh operator");
    AddOutput("Out", "Output of STanh operator");
    AddAttr<float>("scale_a", "The scale parameter of a for the input")
        .SetDefault(2.0f / 3.0f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of ThresholdedRelu operator");
    AddOutput("Out", "Output of ThresholdedRelu operator");
    AddAttr<float>("threshold",
                   "The threshold location of activation. [default 1.0].")
        .SetDefault(1.0f);
    AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`

..  math::

    out = \begin{cases}
             x,  \text{if } x > threshold \\
             0,  \text{otherwise}
          \end{cases}
)DOC");
  }
};

class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSigmoid operator");
    AddOutput("Out", "Output of HardSigmoid operator");
    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
        .SetDefault(0.2f);
    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
        .SetDefault(0.5f);
    AddComment(R"DOC(
HardSigmoid Activation Operator.

Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.

$out = \max(0, \min(1, slope * x + offset))$

The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};

REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};
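
// Naming used by the double-grad makers below (informal): DDX is the gradient
// flowing into X@GRAD (OutputGrad of X@GRAD), DDOut is the gradient produced
// for Out@GRAD, and DX / DOut are first-order gradients w.r.t. X and Out that
// some double-grad kernels (e.g. sqrt, square) additionally emit.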

//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("relu_grad_grad");
    // input1: Out
    op->SetInput("Out", Input("Out"));
    // input2: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // output: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("leaky_relu_grad_grad");
    // input1: X
    op->SetInput("X", Input("X"));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("sqrt_grad_grad");
    op->SetInput("Out", Input("Out"));
    op->SetInput("DX", Output(framework::GradVarName("X")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(Attrs());
    op->SetOutput("DOut", InputGrad("Out"));
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
    : public ::paddle::framework::SingleGradOpDescMaker {
 public:
  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
    auto* op = new ::paddle::framework::OpDesc();
    op->SetType("square_grad_grad");
    op->SetInput("X", Input("X"));
    // Out@GRAD: dy
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));

    op->SetAttrMap(Attrs());

    // X@GRAD: dx
    op->SetOutput("DX", InputGrad("X"));
    // Out@GRAD@GRAD: ddy
    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<::paddle::framework::OpDesc>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,  \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type);                                        \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE##_grad, ops::ActivationOpGrad,                            \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ::paddle::framework::SingleOpInplaceInToOut,         \
                       void>::type)

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
                                      ops::functor<float>>,               \
      ops::ActivationKernel<paddle::platform::CPUDeviceContext,           \
                            ops::functor<double>>);                       \
  REGISTER_OP_CPU_KERNEL(                                                 \
      act_type##_grad,                                                    \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<float>>,                \
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,       \
                                ops::grad_functor<double>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
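
// Sketch of what one expansion registers (assuming the tanh entry is
// (tanh, Tanh, TanhFunctor, TanhGradFunctor)): the "tanh" and "tanh_grad"
// operators, an in-place hint when CanInplaceAct allows it, and float/double
// CPU kernels for both the forward and the backward pass.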

/* ==========================    relu register  ============================= */
REGISTER_OPERATOR(
    relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  paddle::framework::SingleOpInplaceInToOut,
                  ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
    relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);

REGISTER_OP_CPU_KERNEL(
    relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== leaky relu register  ============================ */
REGISTER_OPERATOR(
    leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                  paddle::framework::SingleOpInplaceInToOut,
                  ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
    leaky_relu_grad_grad,
    ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                               LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
    leaky_relu_grad_grad,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<float>>,
    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
                                    ops::LeakyReluGradGradFunctor<double>>,
    ops::ActivationDoubleGradKernel<
        plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
REGISTER_OPERATOR(
    sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                  paddle::framework::SingleOpInplaceInToOut,
                  ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
    sqrt_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
    sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
REGISTER_OPERATOR(
    square, ops::ActivationOp, ops::SquareOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
    paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                  paddle::framework::SingleOpInplaceInToOut,
                  ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
    square_grad_grad,
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);

REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                               SquareGradFunctor);

REGISTER_OP_CPU_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */