activation_mkldnn_op.cc 13.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
16
#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h"
17
#include "paddle/fluid/platform/mkldnn_reuse.h"
18

19
// Forward declaration of phi::DenseTensor (declared by name only here).
namespace phi { class DenseTensor; }  // namespace phi
22

W
wanghuancoder 已提交
23
// Forward declarations inside namespace paddle: an (empty) framework
// namespace and the platform::MKLDNNDeviceContext class name.
namespace paddle {

namespace framework {}  // namespace framework

namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform

}  // namespace paddle

30 31 32
namespace paddle {
namespace operators {

33 34
using framework::DataLayout;
using framework::Tensor;
35 36 37
using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
38 39 40
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
41

42 43 44 45 46 47
// Forward oneDNN activation kernel.
//
// Checks that input "X" carries the MKLDNN layout and a defined memory
// format, then hands the whole computation to a default-constructed
// Functor. Functor::ELEMENT_TYPE selects the OpKernel element type.
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *input = ctx.Input<Tensor>("X");

    // Layout is validated first, then format.
    PADDLE_ENFORCE_EQ(
        input->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        input->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));

    Functor activation{};
    activation(ctx);
  }
};
K
Krzysztof Binias 已提交
59

60 61 62 63 64 65
// Backward oneDNN activation kernel.
//
// Checks that the incoming gradient framework::GradVarName("Out") carries
// the MKLDNN layout and a defined memory format, then hands the gradient
// computation to a default-constructed Functor.
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    // Layout is validated first, then format.
    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));

    Functor activation_grad{};
    activation_grad(ctx);
  }
};

// Runs a oneDNN eltwise forward primitive using `algorithm` on input "X" and
// writes the result to output "Out".
//
// If "Out" shares its buffer with "X" (in-place), the source memory object is
// reused as the destination. Otherwise a separate destination memory is
// acquired for "Out". The primitive runs on the stream from
// MKLDNNDeviceContext::tls(), and the function waits for it to finish.
// Afterwards "Out" is tagged with the MKLDNN layout and with the memory
// format of the destination memory.
//
// Fails with a PreconditionNotMet error when not executed on a CPU place.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     dnnl::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *out = ctx.Output<Tensor>("Out");

  const bool is_inplaced = x->IsSharedBufferWith(*out);

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
  if (is_inplaced) {
    // In-place: the primitive writes into the source memory. mutable_data is
    // still called so "Out" is allocated as type T on this place.
    dst_memory_p = src_memory_p;
    out->mutable_data<T>(ctx.GetPlace());
  } else {
    dst_memory_p = handler.AcquireDstMemory(out);
  }

  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(
      astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
  // Wait for the primitive to complete on this stream.
  astream.wait();

  out->set_layout(DataLayout::kMKLDNN);
  out->set_format(GetMKLDNNFormat(*dst_memory_p));
}

114 115
// Runs a oneDNN eltwise backward primitive using `algorithm`. It consumes:
//   * the forward input "X" as DNNL_ARG_SRC, and
//   * the incoming gradient framework::GradVarName("Out") as
//     DNNL_ARG_DIFF_DST.
// The input gradient framework::GradVarName("X") is written as
// DNNL_ARG_DIFF_SRC.
//
// The function waits for the primitive on the stream from
// MKLDNNDeviceContext::tls(). The output gradient is then tagged with the
// MKLDNN layout and the format of the diff_src memory.
//
// Fails with a PreconditionNotMet error when not executed on a CPU place,
// the same precondition eltwise_forward enforces.
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  dnnl::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_grad must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x, dout);

  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_SRC, *src_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  // Wait for the primitive to complete on this stream.
  astream.wait();

  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

// Backward variant that uses the forward *output* "Out" instead of "X".
//
// "Out" is bound as DNNL_ARG_DST, and the incoming gradient
// framework::GradVarName("Out") as DNNL_ARG_DIFF_DST. The input gradient
// framework::GradVarName("X") is written as DNNL_ARG_DIFF_SRC. In this file
// it backs the *_use_dst_for_bwd algorithms.
//
// The function waits for the primitive on the stream from
// MKLDNNDeviceContext::tls(). The output gradient is then tagged with the
// MKLDNN layout and the format of the diff_src memory.
//
// Fails with a PreconditionNotMet error when not executed on a CPU place,
// the same precondition eltwise_forward enforces.
template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
                          dnnl::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_grad_use_out must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *out = ctx.Input<Tensor>("Out");
  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), out, dout);

  // The handler's "backward src" memory wraps the forward output here. It is
  // passed to the primitive as DNNL_ARG_DST below.
  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_DST, *dst_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  // Wait for the primitive to complete on this stream.
  astream.wait();

  dx->set_layout(DataLayout::kMKLDNN);
  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

172
template <typename T, dnnl::algorithm algorithm>
173
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
174
  void operator()(const framework::ExecutionContext &ctx) const {
175 176 177 178
    eltwise_forward<T>(ctx, algorithm);
  }
};

179
template <typename T, dnnl::algorithm algorithm>
180
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
181
  void operator()(const framework::ExecutionContext &ctx) const {
182 183 184 185
    eltwise_grad<T>(ctx, algorithm);
  }
};

186 187 188 189 190 191 192
// Backward functor for activations whose gradient is computed from the
// forward output "Out" (through eltwise_grad_use_out). The algorithm is fixed
// by the template argument.
template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const dnnl::algorithm backward_algorithm = algorithm;
    eltwise_grad_use_out<T>(ctx, backward_algorithm);
  }
};

A
Adam 已提交
193 194 195 196 197
// GELU forward. The boolean "approximate" attribute chooses the oneDNN
// algorithm:
//   * true  -> eltwise_gelu_tanh
//   * false -> eltwise_gelu_erf
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const dnnl::algorithm gelu_algorithm =
        ctx.Attr<bool>("approximate") ? dnnl::algorithm::eltwise_gelu_tanh
                                      : dnnl::algorithm::eltwise_gelu_erf;
    eltwise_forward<T>(ctx, gelu_algorithm);
  }
};

// GELU backward (computed from the forward input "X"). The boolean
// "approximate" attribute chooses the oneDNN algorithm:
//   * true  -> eltwise_gelu_tanh
//   * false -> eltwise_gelu_erf
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const dnnl::algorithm gelu_algorithm =
        ctx.Attr<bool>("approximate") ? dnnl::algorithm::eltwise_gelu_tanh
                                      : dnnl::algorithm::eltwise_gelu_erf;
    eltwise_grad<T>(ctx, gelu_algorithm);
  }
};

217 218 219 220 221 222 223
// Softplus forward. It has no single eltwise algorithm here and instead
// delegates to custom_softplus_eltwise_forward (from softplus_mkldnn_op.h).
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    custom_softplus_eltwise_forward<T>(ctx);
  }
};

224
template <typename T>
T
tensor-tang 已提交
225
using ReluMKLDNNFunctor =
226
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_relu>;
227

A
Adam 已提交
228 229
template <typename T>
using Relu6MKLDNNFunctor =
230
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
A
Adam 已提交
231

232 233
template <typename T>
using SwishMKLDNNFunctor =
234
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_swish>;
235

J
jakpiase 已提交
236 237
template <typename T>
using HardSwishMKLDNNFunctor =
238
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_hardswish>;
J
jakpiase 已提交
239

240 241 242 243
template <typename T>
using MishMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_mish>;

244 245
template <typename T>
using SigmoidMKLDNNFunctor =
246
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_logistic>;
247

248
template <typename T>
T
tensor-tang 已提交
249
using TanhMKLDNNFunctor =
250
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_tanh>;
251 252

template <typename T>
T
tensor-tang 已提交
253
using SqrtMKLDNNFunctor =
254
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_sqrt>;
255 256

template <typename T>
257
using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
258

J
jakpiase 已提交
259
template <typename T>
260
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
J
jakpiase 已提交
261

262 263 264
template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;

265 266 267 268
template <typename T>
using RoundMKLDNNFunctor =
    MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_round>;

269
template <typename T>
T
tensor-tang 已提交
270
using ReluMKLDNNGradFunctor =
271
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
272

A
Adam 已提交
273 274
template <typename T>
using Relu6MKLDNNGradFunctor =
275
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_bounded_relu>;
A
Adam 已提交
276

277 278
template <typename T>
using SwishMKLDNNGradFunctor =
279
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;
280

J
jakpiase 已提交
281 282
template <typename T>
using HardSwishMKLDNNGradFunctor =
283
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
J
jakpiase 已提交
284

285 286 287 288
template <typename T>
using MishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;

289
template <typename T>
290 291
using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
292

293
template <typename T>
294 295
using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
296 297

template <typename T>
298 299
using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
300 301

template <typename T>
T
tensor-tang 已提交
302
using AbsMKLDNNGradFunctor =
303
    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
J
jakpiase 已提交
304 305

template <typename T>
306 307 308 309 310 311 312
using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;

template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;

313 314 315 316 317
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

318
// Registers MKLDNN kernels on CPUPlace for both `act_type` (forward, using
// `functor`) and `act_type##_grad` (backward, using `grad_functor`). Each is
// instantiated for float and paddle::platform::bfloat16.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(                                                      \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                      \
      ops::MKLDNNActivationKernel<ops::functor<float>>,                    \
      ops::MKLDNNActivationKernel<                                         \
          ops::functor<paddle::platform::bfloat16>>);                      \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>,           \
      ops::MKLDNNActivationGradKernel<                                     \
          ops::grad_functor<paddle::platform::bfloat16>>);
328

329 330 331 332
// Registers only the float forward MKLDNN kernel for `act_type` on CPUPlace.
// No grad kernel and no bfloat16 instantiation are registered.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(act_type, functor) \
  REGISTER_OP_KERNEL(                                                 \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                 \
      ops::MKLDNNActivationKernel<ops::functor<float>>);

J
jakpiase 已提交
333 334
// X-macro listing every activation that has both a forward and a backward
// MKLDNN kernel. KERNEL_MACRO is invoked as
// KERNEL_MACRO(op_name, ForwardFunctor, GradFunctor) once per entry.
// (The parameter was previously named `__macro`; identifiers containing a
// double underscore are reserved for the implementation.)
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(KERNEL_MACRO)                            \
  KERNEL_MACRO(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
  KERNEL_MACRO(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
  KERNEL_MACRO(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);              \
  KERNEL_MACRO(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor);                 \
  KERNEL_MACRO(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
  KERNEL_MACRO(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
  KERNEL_MACRO(mish, MishMKLDNNFunctor, MishMKLDNNGradFunctor);                 \
  KERNEL_MACRO(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);                 \
  KERNEL_MACRO(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);              \
  KERNEL_MACRO(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradUseOutFunctor);  \
  KERNEL_MACRO(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradUseOutFunctor);           \
  KERNEL_MACRO(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
  KERNEL_MACRO(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);
347 348

// Register float and bfloat16 forward and backward kernels for every
// activation listed in FOR_EACH_MKLDNN_KERNEL_FUNCTOR.
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);

// For round, only the float forward kernel is registered.
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);

// softplus: forward-only kernels for float and bfloat16. This uses the
// `ops` namespace alias already declared earlier in the file, so it is not
// re-declared here.
REGISTER_OP_KERNEL(
    softplus, MKLDNN, paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>,
    ops::MKLDNNActivationKernel<
        ops::SoftplusMKLDNNFunctor<paddle::platform::bfloat16>>);