/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/infermeta/backward.h"

DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
}
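
// CanInplaceAct() is a compile-time predicate: an activation may safely write
// its output into the input buffer only when the gradient functor needs the
// forward *output* (kDepOut) or nothing at all (kNoDeps), never the forward
// *input*. Illustrative sketch (ReluGradFunctor is used here only as an
// example of a functor whose FwdDeps() is kDepOut):
//
//   static_assert(CanInplaceAct<ReluGradFunctor<float>>(),
//                 "relu_grad depends only on Out, so relu can be in-place");
//
// REGISTER_ACTIVATION_OP below uses the same predicate through
// std::conditional<CanInplaceAct<...>(), ActFwdInplaceInferer, void>.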

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)           \
  class OP_NAME##OpMaker                                            \
      : public ::paddle::framework::OpProtoAndCheckerMaker {        \
   public:                                                          \
    void Make() override {                                          \
      AddInput("X",                                                 \
               "Input of " #OP_NAME                                 \
               " operator, an N-D Tensor, with data type float32, " \
               "float64 or float16.");                              \
      AddOutput("Out",                                              \
                "Output of " #OP_NAME                               \
                " operator, a Tensor with shape same as input.");   \
      AddComment(OP_COMMENT);                                       \
    }                                                               \
  }
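
// The macro above stamps out the boilerplate OpProtoAndCheckerMaker shared by
// simple single-input/single-output activations. A hypothetical invocation
// (the comment string is made up for illustration) defines a ReluOpMaker
// class whose Make() declares the "X" input, the "Out" output, and the doc:
//
//   REGISTER_ACTIVATION_OP_MAKER(Relu, "Relu Activation Operator.");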

template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};
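
// For a forward op whose gradient depends only on Out (kDepValue == kDepOut),
// Apply() above assembles a grad op description roughly like the following
// sketch (names follow the GradVarName convention, e.g. "Out@GRAD"):
//
//   relu_grad(Out      = forward output,
//             Out@GRAD = incoming gradient)
//       -> X@GRAD      = gradient w.r.t. the forward input
//
// "X" is wired in as an extra input only when kDepX is set or the oneDNN
// (use_mkldnn) path is enabled.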

phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                             const framework::OperatorWithKernel& oper,
                             const std::string& name) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif
  return phi::KernelKey(data_type, ctx.GetPlace());
}
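
// GetKernelType() keys kernel dispatch on the dtype of one named variable plus
// the execution place; the forward op queries "X" and the grad op queries
// GradVarName("Out"). Conceptually the returned key is a (place, dtype) pair,
// e.g. roughly (GPUPlace(0), FLOAT32) for a float32 activation run on GPU 0
// (shown only as an illustration, not an exact constructor call).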

class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC");
  }
};

class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of STanh operator."
             " A Tensor with type float32, float64.");
    AddOutput("Out", "Output of STanh operator. A Tensor with type float32.");
    AddAttr<float>("scale_a", "The scale parameter of a for the input. ")
        .SetDefault(0.67f);
    AddAttr<float>("scale_b", "The scale parameter of b for the input")
        .SetDefault(1.7159f);
    AddComment(R"DOC(
STanh Activation Operator.

$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \\frac{x}{1 + e^{- \beta \ x}}$$

)DOC");
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};
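
// Worked example of the hard_swish formula with the default attributes
// (threshold = 6, scale = 6, offset = 3), shown purely as an illustration:
//
//   x = -4  ->  min(max(0, -1), 6) = 0  ->  out = -4 * 0 / 6 = 0
//   x =  1  ->  min(max(0,  4), 6) = 4  ->  out =  1 * 4 / 6 ≈ 0.667
//   x =  5  ->  min(max(0,  8), 6) = 6  ->  out =  5 * 6 / 6 = 5
//
// With these defaults the op is 0 for x <= -offset, equal to x for
// x >= threshold - offset, and a quadratic ramp in between.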

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});

DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

#define REGISTER_ACTIVATION_CPU_KERNEL(                                     \
    act_type, op_name, functor, grad_functor)                               \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type,                                                             \
      ops::ActivationKernel<phi::CPUContext, ops::functor<float>>,          \
      ops::ActivationKernel<phi::CPUContext, ops::functor<double>>);        \
  REGISTER_OP_CPU_KERNEL(                                                   \
      act_type##_grad,                                                      \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<float>>, \
      ops::ActivationGradKernel<phi::CPUContext, ops::grad_functor<double>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
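
// FOR_EACH_ACTIVATION_OP (see activation_op.h) applies the given macro to
// every activation in its list. Each expansion, like the explicit
// REGISTER_ACTIVATION_OP(relu6, ...) call below, registers the forward op,
// its _grad op, the grad-op makers for both static graph (OpDesc) and dygraph
// (OpBase) modes, and the in-place inferers; the forward in-place inferer is
// attached only when CanInplaceAct<grad_functor<float>>() allows it.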

REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(stanh, STanh, STanhFunctor, STanhGradFunctor);
REGISTER_ACTIVATION_OP(hard_swish,
                       HardSwish,
                       HardSwishFunctor,
                       HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);

/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, behavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculation formula before checkpoint: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, behavior changed when threshold < 0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculation formula before checkpoint: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
463 464
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
465
        false));
466

467
/* ========================================================================== */