/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

// Forward declarations so this translation unit does not need the full
// definitions of these framework/platform types.
namespace pten {
class DenseTensor;
}  // namespace pten

namespace paddle {
namespace framework {}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

30 31 32
namespace paddle {
namespace operators {

33 34
using framework::DataLayout;
using framework::Tensor;
35 36 37
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
38 39 40
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
41

42 43 44 45 46 47
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
48 49 50 51 52 53
    PADDLE_ENFORCE_EQ(
        x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));
54 55 56 57 58

    Functor functor;
    functor(ctx);
  }
};
K
Krzysztof Binias 已提交
59

60 61 62 63 64 65
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
66
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
67 68
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
A
Adam 已提交
69
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
70 71
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));
72 73 74 75 76 77 78 79

    Functor functor;
    functor(ctx);
  }
};

// Runs the oneDNN eltwise forward primitive selected by `algorithm` on input
// "X", writing the result to output "Out". Supports in-place execution when
// X and Out share the same buffer.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     dnnl::algorithm algorithm) {
  // oneDNN kernels only run on CPU.
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *out = ctx.Output<Tensor>("Out");

  bool is_inplaced = x->IsSharedBufferWith(*out);

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
  if (is_inplaced) {
    // In-place: reuse the source memory object; still mark Out's buffer as
    // allocated for type T.
    dst_memory_p = src_memory_p;
    out->mutable_data<T>(ctx.GetPlace());
  } else {
    dst_memory_p = handler.AcquireDstMemory(out);
  }
  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(
      astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
  astream.wait();

  // Propagate MKLDNN layout/format metadata to the output tensor.
  out->set_layout(DataLayout::kMKLDNN);
  out->set_format(GetMKLDNNFormat(*dst_memory_p));
}

114 115
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
116
                  dnnl::algorithm algorithm) {
117
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
118
  const auto &mkldnn_engine = dev_ctx.GetEngine();
119

120
  const auto *x = ctx.Input<Tensor>("X");
121 122
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
123

124
  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
125
                                               ctx.GetPlace(), x, dout);
126

127
  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
128 129
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
A
Adam 已提交
130 131
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

132
  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
A
Adam 已提交
133
  activation_backward_p->execute(astream,
134 135 136
                                 {{DNNL_ARG_SRC, *src_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
A
Adam 已提交
137
  astream.wait();
138

139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

// Variant of eltwise_grad for *_use_dst_for_bwd algorithms: the backward
// pass is computed from the forward output "Out" (DNNL_ARG_DST) instead of
// the forward input "X".
template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
                          dnnl::algorithm algorithm) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *out = ctx.Input<Tensor>("Out");
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), out, dout);

  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_DST, *dst_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();

  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

172
template <typename T, dnnl::algorithm algorithm>
173
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
174
  void operator()(const framework::ExecutionContext &ctx) const {
175 176 177 178
    eltwise_forward<T>(ctx, algorithm);
  }
};

179
template <typename T, dnnl::algorithm algorithm>
180
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
181
  void operator()(const framework::ExecutionContext &ctx) const {
182 183 184 185
    eltwise_grad<T>(ctx, algorithm);
  }
};

186 187 188 189 190 191 192
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad_use_out<T>(ctx, algorithm);
  }
};

A
Adam 已提交
193 194 195 196 197
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const bool approximate = ctx.Attr<bool>("approximate");
    if (approximate) {
198
      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
A
Adam 已提交
199
    } else {
200
      eltwise_forward<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
A
Adam 已提交
201 202 203 204 205 206 207 208 209
    }
  }
};

// GELU backward: mirrors GeluMKLDNNFunctor's algorithm choice so forward and
// backward always use the same formulation.
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const bool approximate = ctx.Attr<bool>("approximate");
    if (approximate) {
      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_tanh);
    } else {
      eltwise_grad<T>(ctx, dnnl::algorithm::eltwise_gelu_erf);
    }
  }
};

217 218 219 220 221 222 223
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    custom_softplus_eltwise_forward<T>(ctx);
  }
};

224
template <typename T>
T
tensor-tang 已提交
225
using ReluMKLDNNFunctor =
226
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
227

A
Adam 已提交
228 229
template <typename T>
using Relu6MKLDNNFunctor =
230
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
A
Adam 已提交
231

232 233
template <typename T>
using SwishMKLDNNFunctor =
234
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
235

J
jakpiase 已提交
236 237
template <typename T>
using HardSwishMKLDNNFunctor =
238
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
J
jakpiase 已提交
239

240 241
template <typename T>
using SigmoidMKLDNNFunctor =
242
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
243

244
template <typename T>
T
tensor-tang 已提交
245
using TanhMKLDNNFunctor =
246
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
247 248

template <typename T>
T
tensor-tang 已提交
249
using SqrtMKLDNNFunctor =
250
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
251 252

template <typename T>
253
using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
254

J
jakpiase 已提交
255
template <typename T>
256
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
J
jakpiase 已提交
257

258 259 260
template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;

261
template <typename T>
T
tensor-tang 已提交
262
using ReluMKLDNNGradFunctor =
263
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
264

A
Adam 已提交
265 266
template <typename T>
using Relu6MKLDNNGradFunctor =
267
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
A
Adam 已提交
268

269 270
template <typename T>
using SwishMKLDNNGradFunctor =
271
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
272

J
jakpiase 已提交
273 274
template <typename T>
using HardSwishMKLDNNGradFunctor =
275
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
J
jakpiase 已提交
276

277
template <typename T>
278 279
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
280

281
template <typename T>
282 283
using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
284 285

template <typename T>
286 287
using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
288 289

template <typename T>
T
tensor-tang 已提交
290
using AbsMKLDNNGradFunctor =
291
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
J
jakpiase 已提交
292 293

template <typename T>
294 295 296 297 298 299 300
using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;

template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers a float32 MKLDNN forward kernel and its matching grad kernel
// for the given activation op type.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

// Same as REGISTER_ACTIVATION_MKLDNN_KERNEL, but additionally registers
// bfloat16 forward and grad kernels for the op.
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor,             \
                                               grad_functor)                  \
  REGISTER_OP_KERNEL(                                                         \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                         \
      ops::MKLDNNActivationKernel<ops::functor<float>>,                       \
      ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>); \
  REGISTER_OP_KERNEL(                                                         \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>,              \
      ops::MKLDNNActivationGradKernel<                                        \
          ops::grad_functor<paddle::platform::bfloat16>>);
// Applies __macro to every activation registered with float32-only kernels.
// (relu, gelu, sigmoid and sqrt are registered separately below with bf16
// support.)
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                            \
  __macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);              \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
  __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);           \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
  __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);

// Activations with bfloat16 support in addition to float32.
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
                                       ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                       GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
                                       SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
                                       SqrtMKLDNNGradUseOutFunctor);

// Softplus has a forward kernel only (no grad kernel registered here).
REGISTER_OP_KERNEL(
    softplus, MKLDNN, paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);