activation_mkldnn_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace pten {
class DenseTensor;
}  // namespace pten

namespace paddle {
namespace framework {}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

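// Forward kernel wrapper: verifies that X carries the MKL-DNN layout and a
// defined memory format before invoking the activation Functor.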
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE_EQ(
        x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));

    Functor functor;
    functor(ctx);
  }
};

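// Backward kernel wrapper: performs the same layout/format checks on the
// Out@GRAD tensor before invoking the gradient Functor.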
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));

    Functor functor;
    functor(ctx);
  }
};

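// Common forward routine: builds an ActivationMKLDNNHandler for the given
// oneDNN eltwise algorithm and executes the forward primitive from X to Out.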
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     dnnl::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eletwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *out = ctx.Output<Tensor>("Out");

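  // When Out shares its buffer with X, reuse the source memory as the
  // destination so the activation runs in place.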
  bool is_inplaced = x->IsSharedBufferWith(*out);

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
  if (is_inplaced) {
    dst_memory_p = src_memory_p;
    out->mutable_data<T>(ctx.GetPlace());
  } else {
    dst_memory_p = handler.AcquireDstMemory(out);
  }
  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(
      astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
  astream.wait();

  out->set_layout(DataLayout::kMKLDNN);
  out->set_format(GetMKLDNNFormat(*dst_memory_p));
}

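// Common backward routine: consumes X and Out@GRAD and writes X@GRAD via the
// backward eltwise primitive.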
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  dnnl::algorithm algorithm) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x, dout);

  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_SRC, *src_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();

  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

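// Backward routine for algorithms whose gradient is derived from the forward
// output: takes Out and Out@GRAD instead of X.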
template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
                          dnnl::algorithm algorithm) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *out = ctx.Input<Tensor>("Out");
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), out, dout);

  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_DST, *dst_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();

  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

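// Thin functors that bind a fixed oneDNN eltwise algorithm to the generic
// kernels above; the UseOut variant feeds the primitive from Out rather than X.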
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad_use_out<T>(ctx, algorithm);
  }
};

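// GELU chooses between the tanh approximation and the exact erf formulation
// based on the "approximate" attribute.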
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const bool approximate = ctx.Attr<bool>("approximate");
    if (approximate) {
      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
    } else {
      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
    }
  }
};

template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const bool approximate = ctx.Attr<bool>("approximate");
    if (approximate) {
      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
    } else {
      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
    }
  }
};

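// Softplus does not map to a single eltwise algorithm; it uses the custom
// forward path defined in softplus_mkldnn_op.h.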
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    custom_softplus_eltwise_forward<T>(ctx);
  }
};

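// Aliases binding each supported activation to a concrete oneDNN eltwise
// algorithm; forward functors come first, followed by their gradient
// counterparts.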
template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;

template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;

template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;

template <typename T>
using HardSwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
using MishMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;

template <typename T>
using SigmoidMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;

template <typename T>
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;

template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;

template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;

template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;

template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;

template <typename T>
using HardSwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
using MishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;

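// The *GradUseOutFunctor aliases rely on the *_use_dst_for_bwd algorithms,
// which compute the gradient from the forward output Out rather than from X.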
template <typename T>
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;

template <typename T>
using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;

template <typename T>
using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;

template <typename T>
using AbsMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;

template <typename T>
using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;

template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

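// Registers float forward ("act_type") and backward ("act_type_grad") oneDNN
// kernels for an activation op; e.g. the abs entry below expands to kernels
// for "abs" and "abs_grad".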
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

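// Same as above, but additionally registers bfloat16 kernels alongside the
// float ones.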
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor,             \
                                               grad_functor)                  \
  REGISTER_OP_KERNEL(                                                         \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                         \
      ops::MKLDNNActivationKernel<ops::functor<float>>,                       \
      ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>); \
  REGISTER_OP_KERNEL(                                                         \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>,              \
      ops::MKLDNNActivationGradKernel<                                        \
          ops::grad_functor<paddle::platform::bfloat16>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                            \
  __macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);              \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
  __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);           \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
  __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
                                       ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                       GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
                                       SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
                                       SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
                                       MishMKLDNNGradFunctor);

REGISTER_OP_KERNEL(
    softplus, MKLDNN, paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);