activation_mkldnn_op.cc 11.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
16
#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h"
17
#include "paddle/fluid/platform/mkldnn_reuse.h"
18

W
wanghuancoder 已提交
19 20 21 22 23 24 25 26 27
// Forward declarations of the platform/framework types named in this file.
// NOTE(review): these may already be provided by the headers included above;
// kept so the names are declared independently of those headers' contents.
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform

namespace framework {
class Tensor;
}  // namespace framework
}  // namespace paddle

28 29 30
namespace paddle {
namespace operators {

31 32 33 34 35 36 37 38
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
39

40 41 42 43 44 45
// Forward kernel shared by the oneDNN (MKLDNN) activation ops in this file.
//
// Validates that input "X" carries the MKLDNN layout and a defined memory
// format. It then delegates the computation to a default-constructed
// `Functor`, which receives the whole execution context.
//
// `Functor` must expose `ELEMENT_TYPE` (used as the kernel's data type) and an
// `operator()(const framework::ExecutionContext &)`.
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE_EQ(
        x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));

    Functor functor;
    functor(ctx);
  }
};

58 59 60 61 62 63
// Backward kernel shared by the oneDNN (MKLDNN) activation ops in this file.
//
// Validates that the incoming gradient of "Out" (framework::GradVarName("Out"))
// carries the MKLDNN layout and a defined memory format. It then delegates the
// computation to a default-constructed `Functor`.
//
// `Functor` must expose `ELEMENT_TYPE` (used as the kernel's data type) and an
// `operator()(const framework::ExecutionContext &)`.
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));

    Functor functor;
    functor(ctx);
  }
};

// Runs a oneDNN eltwise forward primitive for `algorithm` on input "X" and
// writes the result to output "Out".
//
// When "X" and "Out" share a buffer (in-place execution), the source memory
// object is reused as the destination. On return, "Out" is tagged with the
// MKLDNN layout and with the memory format of the destination memory.
//
// Raises PreconditionNotMet if the execution place is not a CPU place.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  // In-place execution: skip acquiring a separate destination buffer.
  bool is_inplaced = x->IsSharedBufferWith(*y);

  // NOTE(review): the handler receives ctx and presumably reads op-specific
  // attributes (e.g. alpha/beta) from it; confirm in platform/mkldnn_reuse.h.
  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                  {MKLDNN_ARG_TO, *dst_memory_p}});
  // Wait for the stream so "Out" holds the result before we tag it.
  astream.wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

106 107
// Runs a oneDNN eltwise backward primitive for `algorithm`.
//
// Reads the forward input "X" and the gradient of "Out"
// (framework::GradVarName("Out")), and writes the gradient of "X"
// (framework::GradVarName("X")). On return, the gradient of "X" is tagged with
// the MKLDNN layout and with the memory format of the diff-src memory.
//
// Raises PreconditionNotMet if the execution place is not a CPU place. This
// matches the check in eltwise_forward.
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_grad must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x, diff_y);

  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
  // Wait for the stream so the gradient of "X" holds the result before we
  // tag it.
  astream.wait();

  diff_x->set_layout(DataLayout::kMKLDNN);
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

// Forward activation functor bound at compile time to one oneDNN eltwise
// `algorithm`; invoking it runs eltwise_forward<T> with that algorithm.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

// Backward activation functor bound at compile time to one oneDNN eltwise
// `algorithm`; invoking it runs eltwise_grad<T> with that algorithm.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

A
Adam 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
// GELU forward through oneDNN. The algorithm is chosen from the op's
// "approximate" attribute: true selects eltwise_gelu_tanh, false selects
// eltwise_gelu_erf.
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const mkldnn::algorithm gelu_algorithm =
        ctx.Attr<bool>("approximate") ? mkldnn::algorithm::eltwise_gelu_tanh
                                      : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_forward<T>(ctx, gelu_algorithm);
  }
};

// GELU backward through oneDNN. The algorithm is chosen from the op's
// "approximate" attribute: true selects eltwise_gelu_tanh, false selects
// eltwise_gelu_erf.
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const mkldnn::algorithm gelu_algorithm =
        ctx.Attr<bool>("approximate") ? mkldnn::algorithm::eltwise_gelu_tanh
                                      : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_grad<T>(ctx, gelu_algorithm);
  }
};

173 174 175 176 177 178 179
// Softplus forward through oneDNN. Unlike the other functors here, it does
// not call eltwise_forward. It delegates to custom_softplus_eltwise_forward<T>,
// which presumably comes from mkldnn/softplus_mkldnn_op.h (confirm).
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    // Same element type as this functor; the whole context is forwarded.
    custom_softplus_eltwise_forward<T>(ctx);
  }
};

180
template <typename T>
T
tensor-tang 已提交
181
using ReluMKLDNNFunctor =
182 183
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
184 185 186 187
template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

188 189 190 191
template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

J
jakpiase 已提交
192 193 194 195
template <typename T>
using HardSwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_hardswish>;

196 197 198 199
template <typename T>
using SigmoidMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_logistic>;

200
template <typename T>
T
tensor-tang 已提交
201
using TanhMKLDNNFunctor =
202 203 204
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
205
using SqrtMKLDNNFunctor =
206 207 208
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
209
using AbsMKLDNNFunctor =
210 211
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

J
jakpiase 已提交
212 213 214 215
template <typename T>
using EluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_elu>;

216
template <typename T>
T
tensor-tang 已提交
217
using ReluMKLDNNGradFunctor =
218 219
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
220 221 222 223
template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

224 225 226 227
template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

J
jakpiase 已提交
228 229 230 231
template <typename T>
using HardSwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_hardswish>;

232 233 234 235
template <typename T>
using SigmoidMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_logistic>;

236
template <typename T>
T
tensor-tang 已提交
237
using TanhMKLDNNGradFunctor =
238 239 240
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
241
using SqrtMKLDNNGradFunctor =
242 243 244
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
245
using AbsMKLDNNGradFunctor =
246
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
J
jakpiase 已提交
247 248 249 250

template <typename T>
using EluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_elu>;
251 252 253 254 255 256 257 258 259 260 261 262
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers float-only oneDNN kernels for `act_type` and `act_type##_grad`.
// `functor` and `grad_functor` are templates reached through the `ops`
// namespace alias, so that alias must be visible where the macro expands.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor)  \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,        \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);     \
  REGISTER_OP_KERNEL(act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
                     ops::MKLDNNActivationGradKernel<                       \
                         ops::grad_functor<float>>);

263 264 265 266 267 268 269 270
// Registers oneDNN kernels for `act_type` and `act_type##_grad` with both
// float and paddle::platform::bfloat16 element types. `functor` and
// `grad_functor` are reached through the `ops` namespace alias, so that alias
// must be visible where the macro expands.
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor,             \
                                               grad_functor)                  \
  REGISTER_OP_KERNEL(                                                         \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                         \
      ops::MKLDNNActivationKernel<ops::functor<float>>,                       \
      ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>); \
  REGISTER_OP_KERNEL(                                                         \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>,              \
      ops::MKLDNNActivationGradKernel<                                        \
          ops::grad_functor<paddle::platform::bfloat16>>);

J
jakpiase 已提交
275 276 277 278 279 280 281 282 283
// X-macro: invokes `act_macro(op_type, functor, grad_functor)` once per
// activation listed below.
//
// The parameter is no longer spelled `__macro`: identifiers containing a
// double underscore are reserved to the implementation.
//
// leaky_relu reuses the ReLU functors (oneDNN eltwise_relu).
// NOTE(review): the negative slope is presumably read from the op's alpha
// attribute by ActivationMKLDNNHandler; confirm in platform/mkldnn_reuse.h.
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(act_macro)                            \
  act_macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);              \
  act_macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
  act_macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
  act_macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
  act_macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                 \
  act_macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);                 \
  act_macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
  act_macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);

// Float-only kernels for every activation in FOR_EACH_MKLDNN_KERNEL_FUNCTOR.
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);

// Activations that also get bfloat16 kernels (forward and grad).
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
                                       ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                       GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
                                       SigmoidMKLDNNGradFunctor);

// softplus: forward-only float kernel; no grad kernel is registered here.
// The `ops` alias declared earlier in this file is reused, not redeclared.
REGISTER_OP_KERNEL(
    softplus, MKLDNN, paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);