/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

// Forward declarations of Paddle types referenced below, so this file does not
// depend on their full definitions being pulled in transitively.
namespace paddle {
namespace framework {
class Tensor;
}  // namespace framework

namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

// Paddle framework types.
using framework::DataLayout;
using framework::Tensor;

// Paddle platform (oneDNN glue) helpers.
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

// oneDNN (mkldnn) API types.
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;

// Forward kernel shared by the oneDNN (MKL-DNN) activation ops.
//
// Checks that input "X" is in the MKLDNN layout with a defined memory format,
// then hands the actual computation to `Functor`. The functor's ELEMENT_TYPE
// selects the kernel data type.
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *input = ctx.Input<Tensor>("X");

    // Reject tensors that do not carry MKLDNN layout/format metadata.
    PADDLE_ENFORCE_EQ(
        input->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        input->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));

    Functor run_activation;
    run_activation(ctx);
  }
};

// Backward kernel shared by the oneDNN (MKL-DNN) activation ops.
//
// Checks that the incoming gradient of "Out" (framework::GradVarName("Out"))
// is in the MKLDNN layout with a defined memory format, then hands the gradient
// computation to `Functor`.
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const Tensor *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    // Reject gradients that do not carry MKLDNN layout/format metadata.
    PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
    PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));

    Functor run_activation_grad;
    run_activation_grad(ctx);
  }
};

// Runs a oneDNN eltwise forward primitive that computes "Out" from "X" with
// the given `algorithm`.
//
// The primitive and its memories come from platform::ActivationMKLDNNHandler.
// That handler also receives `ctx`, presumably so it can read
// algorithm-specific attributes (TODO confirm in mkldnn_reuse.h).
//
// If "X" and "Out" share one buffer, the source memory doubles as the
// destination and the op runs in place. On return, "Out" is tagged with the
// MKLDNN layout and the memory format of the destination memory.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  // In-place execution: reuse the source memory as the destination.
  bool is_inplaced = x->IsSharedBufferWith(*y);

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x);

  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireForwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                  {MKLDNN_ARG_TO, *dst_memory_p}});
  // Wait for the stream so "Out" is complete before its metadata is set.
  astream.wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

// Runs a oneDNN eltwise backward primitive. Given "X" and the gradient of
// "Out", it computes the gradient of "X" for the given `algorithm`.
//
// The handler is built from both "X" and the output gradient. It supplies the
// backward source memory (from "X"), the diff-dst memory (from the output
// gradient) and the diff-src memory (written into the "X" gradient). On return,
// the "X" gradient is tagged with the MKLDNN layout and the format of the
// diff-src memory.
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm) {
  // Same precondition as eltwise_forward: these oneDNN kernels are CPU-only.
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eltwise_grad must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                               ctx.GetPlace(), x, diff_y);

  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
  // Wait for the stream so the "X" gradient is complete before its metadata
  // is set.
  astream.wait();

  diff_x->set_layout(DataLayout::kMKLDNN);
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

// Forward functor for an activation that maps onto one oneDNN eltwise
// algorithm. The algorithm is fixed at compile time by the `algorithm`
// template argument.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &context) const {
    eltwise_forward<T>(context, /*algorithm=*/algorithm);
  }
};

// Backward functor paired with MKLDNNActivationFunc. It runs eltwise_grad<T>
// with the oneDNN algorithm fixed by the `algorithm` template argument.
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &context) const {
    eltwise_grad<T>(context, /*algorithm=*/algorithm);
  }
};

// Forward GELU functor. The "approximate" attribute selects the oneDNN
// algorithm: eltwise_gelu_tanh when true, eltwise_gelu_erf otherwise.
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto gelu_algorithm = ctx.Attr<bool>("approximate")
                                    ? mkldnn::algorithm::eltwise_gelu_tanh
                                    : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_forward<T>(ctx, gelu_algorithm);
  }
};

// Backward GELU functor. It picks the same algorithm as GeluMKLDNNFunctor from
// the "approximate" attribute: eltwise_gelu_tanh when true, eltwise_gelu_erf
// otherwise.
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto gelu_algorithm = ctx.Attr<bool>("approximate")
                                    ? mkldnn::algorithm::eltwise_gelu_tanh
                                    : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_grad<T>(ctx, gelu_algorithm);
  }
};

template <typename T>
T
tensor-tang 已提交
173
using ReluMKLDNNFunctor =
174 175
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
176 177 178 179
template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

180 181 182 183
template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

184 185 186 187
template <typename T>
using HardSwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_hardswish>;

188 189 190 191
template <typename T>
using SigmoidMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_logistic>;

192
template <typename T>
T
tensor-tang 已提交
193
using TanhMKLDNNFunctor =
194 195 196
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
197
using SqrtMKLDNNFunctor =
198 199 200
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
201
using AbsMKLDNNFunctor =
202 203 204
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
T
tensor-tang 已提交
205
using ReluMKLDNNGradFunctor =
206 207
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
208 209 210 211
template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

212 213 214 215
template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

216 217 218 219
template <typename T>
using HardSwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_hardswish>;

220 221 222 223
template <typename T>
using SigmoidMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_logistic>;

224
template <typename T>
T
tensor-tang 已提交
225
using TanhMKLDNNGradFunctor =
226 227 228
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
229
using SqrtMKLDNNGradFunctor =
230 231 232
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
233
using AbsMKLDNNGradFunctor =
234 235 236 237 238 239 240 241 242 243 244 245 246
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers an FP32 oneDNN kernel for the op `op_type` and another for its
// "<op_type>_grad" op. The forward kernel wraps `fwd_functor<float>` and the
// gradient kernel wraps `bwd_functor<float>`.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(op_type, fwd_functor, bwd_functor) \
  REGISTER_OP_KERNEL(op_type, MKLDNN, ::paddle::platform::CPUPlace,          \
                     ops::MKLDNNActivationKernel<ops::fwd_functor<float>>);  \
  REGISTER_OP_KERNEL(                                                        \
      op_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::bwd_functor<float>>);

// Like REGISTER_ACTIVATION_MKLDNN_KERNEL, but registers both a float and a
// paddle::platform::bfloat16 kernel for the op `op_type` and for its
// "<op_type>_grad" op.
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(op_type, fwd_functor,            \
                                               bwd_functor)                     \
  REGISTER_OP_KERNEL(                                                           \
      op_type, MKLDNN, ::paddle::platform::CPUPlace,                            \
      ops::MKLDNNActivationKernel<ops::fwd_functor<float>>,                     \
      ops::MKLDNNActivationKernel<ops::fwd_functor<paddle::platform::bfloat16>>); \
  REGISTER_OP_KERNEL(                                                           \
      op_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                     \
      ops::MKLDNNActivationGradKernel<ops::bwd_functor<float>>,                 \
      ops::MKLDNNActivationGradKernel<                                          \
          ops::bwd_functor<paddle::platform::bfloat16>>);

// X-macro listing the activations that get FP32-only oneDNN kernels. Each entry
// (op type, forward functor, backward functor) is passed to `register_kernel`.
// leaky_relu is registered with the relu functors. The leak slope is presumably
// read from the op's attributes by platform::ActivationMKLDNNHandler
// (TODO confirm).
//
// The parameter used to be spelled `__macro`. Identifiers that contain a double
// underscore are reserved to the implementation, so it now has an ordinary
// name.
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(register_kernel)                  \
  register_kernel(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
  register_kernel(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);    \
  register_kernel(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  register_kernel(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);    \
  register_kernel(hardswish, HardSwishMKLDNNFunctor,                     \
                  HardSwishMKLDNNGradFunctor);                           \
  register_kernel(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
  register_kernel(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
  register_kernel(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

// FP32-only oneDNN kernels for every entry of FOR_EACH_MKLDNN_KERNEL_FUNCTOR.
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);

// gelu and sigmoid get both float and bfloat16 oneDNN kernels.
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(
    gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(
    sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradFunctor);