// activation_mkldnn_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
// Forward declarations only: these types are used here solely by reference,
// so declaring them avoids pulling in the full framework/platform headers.
namespace paddle {
namespace framework {
class Tensor;
}  // namespace framework
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

28 29 30
namespace paddle {
namespace operators {

31 32 33 34 35 36 37 38
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
39

40 41 42 43 44 45
// Forward activation kernel: validates that input "X" is already in MKLDNN
// layout/format, then delegates the actual primitive work to Functor
// (e.g. GeluMKLDNNFunctor or an MKLDNNActivationFunc instantiation).
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    // Both checks guard against tensors that were not produced by (or
    // converted for) an MKLDNN operator.
    PADDLE_ENFORCE_EQ(
        x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));

    Functor functor;
    functor(ctx);
  }
};
K
Krzysztof Binias 已提交
57

58 59 60 61 62 63
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
64
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
65 66
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
A
Adam 已提交
67
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
68 69
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));
70 71 72 73 74 75 76 77

    Functor functor;
    functor(ctx);
  }
};

// Shared forward driver for all oneDNN eltwise activations: builds the
// activation primitive for `algorithm`, executes it on X -> Out, and tags
// the output with its MKLDNN layout/format.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  // When X and Out share a buffer (in-place execution) the source memory is
  // reused as destination instead of acquiring a separate one.
  bool is_inplaced = x->IsSharedBufferWith(*y);

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                  {MKLDNN_ARG_TO, *dst_memory_p}});
  astream.wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

106 107
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
A
Adam 已提交
108
                  mkldnn::algorithm algorithm) {
109
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
110
  const auto &mkldnn_engine = dev_ctx.GetEngine();
111

112
  const auto *x = ctx.Input<Tensor>("X");
113 114
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
115

116 117
  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x, diff_y);
118

119 120 121
  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
A
Adam 已提交
122 123
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

124
  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
A
Adam 已提交
125 126 127 128 129
  activation_backward_p->execute(astream,
                                 {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();
130

131
  diff_x->set_layout(DataLayout::kMKLDNN);
132
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
133 134 135 136
}

// Generic forward functor: binds a fixed oneDNN eltwise algorithm to the
// shared eltwise_forward driver.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

// Generic backward functor: binds a fixed oneDNN eltwise algorithm to the
// shared eltwise_grad driver.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

A
Adam 已提交
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const bool approximate = ctx.Attr<bool>("approximate");
    if (approximate) {
      eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_gelu_tanh);
    } else {
      eltwise_forward<T>(ctx, mkldnn::algorithm::eltwise_gelu_erf);
    }
  }
};

// gelu backward: mirrors GeluMKLDNNFunctor, picking the tanh-approximation
// or exact (erf) gradient algorithm from the "approximate" attribute.
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto algo = ctx.Attr<bool>("approximate")
                          ? mkldnn::algorithm::eltwise_gelu_tanh
                          : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_grad<T>(ctx, algo);
  }
};

173 174 175 176 177 178 179
template <typename T>
struct SoftplusMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    custom_softplus_eltwise_forward<T>(ctx);
  }
};

180
template <typename T>
T
tensor-tang 已提交
181
using ReluMKLDNNFunctor =
182 183
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
184 185 186 187
template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

188 189 190 191
template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

J
jakpiase 已提交
192 193 194 195
template <typename T>
using HardSwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_hardswish>;

196 197 198 199
template <typename T>
using SigmoidMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_logistic>;

200
template <typename T>
T
tensor-tang 已提交
201
using TanhMKLDNNFunctor =
202 203 204
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
205
using SqrtMKLDNNFunctor =
206 207 208
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
209
using AbsMKLDNNFunctor =
210 211 212
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
T
tensor-tang 已提交
213
using ReluMKLDNNGradFunctor =
214 215
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
216 217 218 219
template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

220 221 222 223
template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

J
jakpiase 已提交
224 225 226 227
template <typename T>
using HardSwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_hardswish>;

228 229 230 231
template <typename T>
using SigmoidMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_logistic>;

232
template <typename T>
T
tensor-tang 已提交
233
using TanhMKLDNNGradFunctor =
234 235 236
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
237
using SqrtMKLDNNGradFunctor =
238 239 240
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
241
using AbsMKLDNNGradFunctor =
242 243 244 245 246 247 248 249 250 251 252 253 254
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers the fp32 forward + backward MKLDNN kernels for one activation.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

// Same as above, but additionally registers bfloat16 variants for the
// activations that support them.
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor,             \
                                               grad_functor)                  \
  REGISTER_OP_KERNEL(                                                         \
      act_type, MKLDNN, ::paddle::platform::CPUPlace,                         \
      ops::MKLDNNActivationKernel<ops::functor<float>>,                       \
      ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>); \
  REGISTER_OP_KERNEL(                                                         \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>,              \
      ops::MKLDNNActivationGradKernel<                                        \
          ops::grad_functor<paddle::platform::bfloat16>>);

// fp32-only activations. Note: leaky_relu reuses the eltwise_relu algorithm
// (oneDNN's relu takes an alpha/negative-slope parameter).
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                           \
  __macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);             \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);          \
  __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);             \
  __macro(hardswish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);                \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
                                       ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                       GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
                                       SigmoidMKLDNNGradFunctor);

// softplus is forward-only here (no MKLDNN grad functor), so it is
// registered directly instead of via the paired macros above.
REGISTER_OP_KERNEL(
    softplus, MKLDNN, paddle::platform::CPUPlace,
    ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);