/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"

#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/backward.h"
DECLARE_bool(use_mkldnn);

namespace paddle {
namespace operators {

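// An activation can safely run in place (sharing X's buffer for Out) only when
// its gradient functor does not need the forward input X, i.e. it depends on
// the forward output Out or on nothing at all.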
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
}

#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)           \
  class OP_NAME##OpMaker                                            \
      : public ::paddle::framework::OpProtoAndCheckerMaker {        \
   public:                                                          \
    void Make() override {                                          \
      AddInput("X",                                                 \
               "Input of " #OP_NAME                                 \
               " operator, an N-D Tensor, with data type float32, " \
               "float64 or float16.");                              \
      AddOutput("Out",                                              \
                "Output of " #OP_NAME                               \
                " operator, a Tensor with shape same as input.");   \
      AddComment(OP_COMMENT);                                       \
    }                                                               \
  }

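// Generic grad-op maker for activations: the grad op always consumes dOut and
// produces dX, and additionally forwards X and/or Out depending on which
// forward tensors the grad functor declares it needs (kDepX / kDepOut).
// X is also forwarded when the oneDNN path is enabled.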
template <ActBwdOpFwdDeps kDepValue, typename T>
class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};
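
// Composite grad maker for hard_swish: instead of a hand-written grad kernel,
// the backward pass is expanded into primitive ops via prim::hardswish_grad.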
class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
 public:
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 protected:
  void Apply() override {
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    VLOG(6) << "Running hardswish_grad composite func";
    prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
    this->RecoverOutputName(dx, dx_name);
  }
};

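// Builds the kernel key shared by all activation ops from the data type of the
// named input/output variable and the place the op runs on.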
phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                             const framework::OperatorWithKernel& oper,
                             const std::string& name) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  // FIXME(liuwei1031) temporarily disable the code to unblock users
  // TODO(liuwei1031) figure out the reason behind
  // https://github.com/PaddlePaddle/Paddle/issues/16096
  // and re-enable this in the future
  // #ifdef PADDLE_WITH_CUDA
  //   auto it1 = oper.Attrs().find("use_cudnn");
  //   if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
  //     library = framework::LibraryType::kCUDNN;
  //   }
  // #endif
  return phi::KernelKey(data_type, ctx.GetPlace());
}

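// Forward activation op: Out always has the same shape and LoD as X.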
class ActivationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "X");
  }
};

class ActivationOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
      const override {
    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
    return m;
  }
};

class ActivationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};

class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of SoftRelu operator");
    AddOutput("Out", "Output of SoftRelu operator");
    AddAttr<float>("threshold", "The threshold value of SoftRelu")
        .SetDefault(40.0f);
    AddComment(R"DOC(
SoftRelu Activation Operator.

$$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$

)DOC");
  }
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "Input of relu6 operator, an N-D Tensor, "
             "with data type float32, float64.");
    AddOutput(
        "Out",
        "Output of relu6 operator, a Tensor with the same shape as input.");
    AddAttr<float>("threshold",
                   "The threshold value of Relu6. Default is 6.0. ")
        .SetDefault(6.0f);
    AddComment(R"DOC(
Relu6 Activation Operator.

$$out = \min(\max(0, x), threshold)$$

)DOC");
  }
};

class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Swish operator");
    AddOutput("Out", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$out = \frac{x}{1 + e^{-\beta x}}$$

)DOC");
  }
};

class MishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of Mish operator");
    AddOutput("Out", "Output of Mish operator");
    AddAttr<float>(
        "threshold",
        "Constant threshold of softplus in Mish operator. Approximate value "
        "of softplus will be used if absolute value of input is greater than "
        ":attr:`threshold`")
        .SetDefault(20.f);
    AddComment(R"DOC(
Mish Activation Operator.

..  math::
    softplus(x) = \begin{cases}
            x, \text{if } x > \text{threshold} \\
            \ln(1 + e^{x}),  \text{otherwise}
          \end{cases}

    out = x * \tanh(softplus(x))

)DOC");
  }
};

class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "Input of HardSwish operator");
    AddOutput("Out", "Output of HardSwish operator");
    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
        .SetDefault(6.0f);
    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
        .SetDefault(3.0f);
    AddComment(R"DOC(
HardSwish Activation Operator.

The hard version of swish (https://arxiv.org/pdf/1905.02244.pdf).

$$out = \frac{x * \min(\max(0, x + offset), threshold)}{scale}$$

The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.

)DOC");
  }
};

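// Higher-order (double/triple) grad ops: InferShape simply mirrors the shape
// and LoD of X or Out onto each optional output, depending on which forward
// tensor the functor's backward needs (kDepX / kDepOut); the kernel key is
// taken from DDX.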
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DOut")) {
        ctx->ShareDim("Out", "DOut");
        ctx->ShareLoD("Out", "DOut");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
      if (ctx->HasOutput("DOutNew")) {
        ctx->ShareDim("Out", "DOutNew");
        ctx->ShareLoD("Out", "DOutNew");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("Out", "DDOut");
        ctx->ShareLoD("Out", "DDOut");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

template <ActBwdOpFwdDeps kDepValue>
class ActivationOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
        ctx->ShareLoD("X", "DX");
      }
      if (ctx->HasOutput("DDOut")) {
        ctx->ShareDim("X", "DDOut");
        ctx->ShareLoD("X", "DDOut");
      }
    }
    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
      if (ctx->HasOutput("D_DOut")) {
        ctx->ShareDim("Out", "D_DOut");
        ctx->ShareLoD("Out", "D_DOut");
      }
      if (ctx->HasOutput("D_OutNew")) {
        ctx->ShareDim("Out", "D_OutNew");
        ctx->ShareLoD("Out", "D_OutNew");
      }
      if (ctx->HasOutput("D_DDx")) {
        ctx->ShareDim("DDX", "D_DDx");
        ctx->ShareLoD("DDX", "D_DDx");
      }
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return GetKernelType(ctx, *this, "DDX");
  }
};

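// In-place buffer reuse: grad ops may write dX over dOut (and DDOut over DDX),
// and the forward op may write Out over X when CanInplaceAct() holds.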
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
                           {framework::GradVarName("Out"),  // dout
                            framework::GradVarName("X")});  // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
                           {"DDX", "D_DOut"});

DECLARE_INPLACE_OP_INFERER(ActFwdInplaceInferer, {"X", "Out"});

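// CPU kernels for the ops still registered through this file (currently only
// soft_relu); they wrap the element-wise functors from activation_op.h.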
#define DEFINE_ACTIVATION_CPU_KERNEL(op_name, functor, grad_functor)           \
  template <typename T, typename DeviceContext>                                \
  class op_name##Kernel : public ActivationKernel<DeviceContext, functor<T>> { \
  };                                                                           \
                                                                               \
  template <typename T, typename DeviceContext>                                \
  class op_name##GradKernel                                                    \
      : public ActivationGradKernel<DeviceContext, grad_functor<T>> {};

DEFINE_ACTIVATION_CPU_KERNEL(SoftRelu, SoftReluFunctor, SoftReluGradFunctor)

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

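// Registers the forward op (with its proto maker, var-type inference, and grad
// op makers for both static graph OpDesc and dygraph OpBase), enables forward
// in-place only when CanInplaceAct() holds, and registers the matching grad op.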
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
  REGISTER_OPERATOR(                                                        \
      KERNEL_TYPE,                                                          \
      ops::ActivationOp,                                                    \
      ops::OP_NAME##OpMaker,                                                \
      ops::ActivationOpInferVarType,                                        \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::framework::OpDesc>,                \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
                                 paddle::imperative::OpBase>,               \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
                       ops::ActFwdInplaceInferer,                           \
                       void>::type);                                        \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
                    ops::ActivationOpGrad,                                  \
                    ops::ActivationGradOpInplaceInferer);

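// Same as REGISTER_ACTIVATION_OP, but additionally registers
// OP_NAME##CompositeGradOpMaker so the backward pass can be decomposed into
// primitive ops.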
#define REGISTER_ACTIVATION_OP_WITH_COMP(                              \
    KERNEL_TYPE, OP_NAME, functor, grad_functor)                       \
  REGISTER_OPERATOR(                                                   \
      KERNEL_TYPE,                                                     \
      ops::ActivationOp,                                               \
      ops::OP_NAME##OpMaker,                                           \
      ops::ActivationOpInferVarType,                                   \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),  \
                                 paddle::framework::OpDesc>,           \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),  \
                                 paddle::imperative::OpBase>,          \
      ops::OP_NAME##CompositeGradOpMaker,                              \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
                       ops::ActFwdInplaceInferer,                      \
                       void>::type);                                   \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                \
                    ops::ActivationOpGrad,                             \
                    ops::ActivationGradOpInplaceInferer);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name)                \
  PD_REGISTER_STRUCT_KERNEL(                                             \
      act_type, CPU, ALL_LAYOUT, ops::op_name##Kernel, float, double) {} \
  PD_REGISTER_STRUCT_KERNEL(act_type##_grad,                             \
                            CPU,                                         \
                            ALL_LAYOUT,                                  \
                            ops::op_name##GradKernel,                    \
                            float,                                       \
                            double) {}

REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)

REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
                                 HardSwish,
                                 HardSwishFunctor,
                                 HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);

/* ==========================  register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu)
    .AddCheckpoint(
        R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "leaky_relu calculate formula before checkponit: out = max(x, "
                "alpha * x); after checkpoint: out = x if x > 0 else alpha * "
                "x"));

REGISTER_OP_VERSION(hard_shrink)
    .AddCheckpoint(
        R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .BugfixWithBehaviorChanged(
                "hard_shrink calculate formula before checkponit: out = x * "
                "((x < -threshold) + (x > threshold)); after checkpoint: out = "
                "x * (((x < -threshold) + (x > threshold)) > 0)"));

REGISTER_OP_VERSION(softplus).AddCheckpoint(
    R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
         " softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
         " stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("beta", "The beta value of the new formula", 1.0f)
        .NewAttr("threshold", "The threshold value of the new formula", 20.0f));

REGISTER_OP_VERSION(mish).AddCheckpoint(
    R"ROC(add new attributes [use_mkldnn], and when computing softplus the formula is changed as the new veriosn of softplus)ROC",
    paddle::framework::compatible::OpVersionDesc().NewAttr(
        "use_mkldnn",
        "(bool, default false) Only used in mkldnn kernel",
        false));

/* ========================================================================== */