/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/phi/backends/dynload/port.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

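// An activation may run in place (Out reusing X's buffer) only when its grad
// functor never reads X, i.e. it depends on Out (kDepOut) or on nothing
// (kNoDeps).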
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
}

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)           \
  class OP_NAME##OpMaker                                            \
      : public ::paddle::framework::OpProtoAndCheckerMaker {        \
   public:                                                          \
    void Make() override {                                          \
      AddInput("X",                                                 \
               "Input of " #OP_NAME                                 \
               " operator, an N-D Tensor, with data type float32, " \
               "float64 or float16.");                              \
      AddOutput("Out",                                              \
                "Output of " #OP_NAME                               \
                " operator, a Tensor with shape same as input.");   \
      AddComment(OP_COMMENT);                                       \
    }                                                               \
  }

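// Builds the "<forward op>_grad" desc: it always consumes dOut and produces
// dX, and forwards X and/or Out from the forward op only when the grad
// functor (kDepValue) or the use_mkldnn path needs them.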
template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};

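// Common kernel-type deduction shared by the activation ops: the kernel data
// type follows the variable `name`, and the kernel runs on the place of the
// current execution context.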
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}

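// Forward op: Out keeps the shape and LoD of X, and the kernel data type is
// taken from X.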
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "The input is a multi-dimensional Tensor. The data type is "
             "float32, float64.");
    AddOutput("Out",
              "The output is a multi-dimensional Tensor which has the same "
              "dimension and data type as ``X``.");
    AddAttr<float>("t_min", "The min marginal value of BRelu")
        .SetDefault(static_cast<float>(0));
    AddAttr<float>("t_max", "The max marginal value of BRelu")
        .SetDefault(static_cast<float>(24));
    AddComment(R"DOC(
BRelu Activation Operator.

$$out = \min(\max(x, t_{min}), t_{max})$$

)DOC");
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC");
  }
};

class PowOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Pow operator");
    AddInput("FactorTensor",
             "(Tensor<float>, optional). If provided, pow will use this as "
             "the exponential factor. The shape of FactorTensor must be [1]. "
             "It has higher priority than attr(factor).")
        .AsDispensable();
    AddOutput("Out", "Output of Pow operator");
    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
    AddComment(R"DOC(
Pow Activation Operator.

$$out = x^{factor}$$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

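// Double-grad op: every optional output (DX, DOut, DOutNew, DDOut) simply
// shares its dim and LoD with the forward tensor it corresponds to (X or
// Out), as selected by kDepValue; the kernel dtype follows DDX.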
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

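// Triple-grad variant: DX and DDOut follow X, D_DOut and D_OutNew follow Out,
// and D_DDx follows DDX.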
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

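// In-place hints for the grad ops: dX may reuse dOut's buffer, and DDOut /
// D_DOut may reuse DDX's.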
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});

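// pow has dedicated maker / op / grad classes (instead of the generic
// activation ones) because of the optional FactorTensor input, which must be
// forwarded to the grad op and skipped during data transform.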
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("pow_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetInput("FactorTensor", this->Input("FactorTensor"));
    op->SetAttrMap(this->Attrs());
  }
};
class PowOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
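      // FactorTensor holds a single scalar exponent; reporting the expected
      // kernel type for it means the framework does not insert a data
      // transform for this input.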
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};

class PowOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    if (var_name == "FactorTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};
DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

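// Registers KERNEL_TYPE and KERNEL_TYPE_grad. The forward op is declared
// in-place capable only when CanInplaceAct() says the grad functor does not
// need X.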
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

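// Registers the float and double CPU kernels for an activation and its grad.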
#define REGISTER_ACTIVATION_CPU_KERNEL(                                     \
    act_type, op_name, functor, grad_functor)                               \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type,                                                             \
      ops::ActivationKernel<phi::CPUContext, ops::functor<float>>,          \
      ops::ActivationKernel<phi::CPUContext, ops::functor<double>>);        \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<float>>, \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<double>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);

REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_OP(hard_swish,
                       HardSwish,
                       HardSwishFunctor,
                       HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);

/* ==========================   pow register  ============================ */

REGISTER_OPERATOR(
    pow,
    ops::PowOp,
    ops::PowOpMaker,
    ops::ActivationOpInferVarType,
    ops::PowGradOpMaker<paddle::framework::OpDesc>,
    ops::PowGradOpMaker<paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(pow_grad,
                  ops::PowOpGrad,
                  ops::ActivationGradOpInplaceInferer);
/* ========================================================================== */

/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, behavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkpoint: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, behavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkpoint: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed to the new version of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
        false));

/* ========================================================================== */