activation_mkldnn_op.cc 10.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
16
#include "paddle/fluid/platform/mkldnn_reuse.h"
17

// Forward declarations of Paddle types that this translation unit refers to
// without needing their full definitions at this point.
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;
}  // namespace platform
namespace framework {
class Tensor;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {

30 31 32 33 34 35 36 37
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
38

39 40 41 42 43 44
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
45 46 47 48 49 50
    PADDLE_ENFORCE_EQ(
        x->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
    PADDLE_ENFORCE_NE(
        x->format(), MKLDNNMemoryFormat::undef,
        platform::errors::InvalidArgument("Wrong format set for X tensor"));
51 52 53 54 55

    Functor functor;
    functor(ctx);
  }
};
K
Krzysztof Binias 已提交
56

57 58 59 60 61 62
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
63
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
64 65
                      platform::errors::InvalidArgument(
                          "Wrong layout set for Input OutGrad tensor"));
A
Adam 已提交
66
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
67 68
                      platform::errors::InvalidArgument(
                          "Wrong format set for Input OutGrad tensor"));
69 70 71 72 73 74 75 76

    Functor functor;
    functor(ctx);
  }
};

/// Runs a oneDNN (MKL-DNN) eltwise forward primitive for the activation op
/// bound to `ctx`.
///
/// Reads input "X" and writes output "Out". Out is tagged with the MKLDNN
/// layout and the memory format chosen by the primitive.
///
/// @param ctx       Execution context of the activation operator; must be on
///                  CPUPlace, and "X" must be 2-, 3- or 4-dimensional.
/// @param algorithm oneDNN eltwise algorithm to execute.
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                    paddle::platform::errors::PreconditionNotMet(
                        "Operator DNNL eletwise_forward must use CPUPlace"));
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  // Optional attributes default to 0 when the op does not define them.
  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
    // The bound for bounded_relu is passed to oneDNN as alpha; Paddle
    // exposes it as the "threshold" attribute.
    alpha = ctx.Attr<T>("threshold");
  }

  // Use the same PADDLE_ENFORCE_EQ style as the other checks in this file and
  // evaluate the rank only once.
  const auto rank = x->dims().size();
  PADDLE_ENFORCE_EQ(
      rank == 2 || rank == 3 || rank == 4, true,
      platform::errors::Unimplemented("Input dim must be with 2, 3 or 4"));

  auto src_tz = framework::vectorize<int64_t>(x->dims());

  // 2-D inputs are always described as `nc`; otherwise the tensor's own
  // format is used.
  auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();

  platform::ActivationMKLDNNHandler<T> handler(
      src_tz, algorithm, alpha, beta, src_format, dev_ctx, ctx.GetPlace(),
      ctx.InputName("X"));

  auto src_memory_p = handler.AcquireSrcMemory(x);
  // When X and Out share a buffer, reuse the source memory as destination.
  auto dst_memory_p =
      x->IsSharedBufferWith(*y) ? src_memory_p : handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireForwardPrimitive();

  mkldnn::stream astream(dev_ctx.GetEngine());
  activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                  {MKLDNN_ARG_TO, *dst_memory_p}});
  astream.wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

122 123
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
A
Adam 已提交
124
                  mkldnn::algorithm algorithm) {
125 126
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();

127
  const auto *x = ctx.Input<Tensor>("X");
128 129
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
130

131 132 133 134 135 136
  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
A
Adam 已提交
137 138
  } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
    alpha = ctx.Attr<T>("threshold");
139
  }
A
Adam 已提交
140

A
Adam 已提交
141
  auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
K
Krzysztof Binias 已提交
142

143 144
  // diff_dst and src dims should be the same
  auto src_format =
145
      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
146

147
  auto diff_y_format =
148
      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
149

150 151
  platform::ActivationMKLDNNHandler<T> handler(
      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
H
hong 已提交
152
      ctx.GetPlace(), ctx.InputName("X"));
153

154 155 156
  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
A
Adam 已提交
157 158 159 160 161 162 163 164
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  mkldnn::stream astream(dev_ctx.GetEngine());
  activation_backward_p->execute(astream,
                                 {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();
165

166
  diff_x->set_layout(DataLayout::kMKLDNN);
167
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
168 169 170 171
}

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
172
  void operator()(const framework::ExecutionContext &ctx) const {
173 174 175 176 177 178
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
179
  void operator()(const framework::ExecutionContext &ctx) const {
180 181 182 183
    eltwise_grad<T>(ctx, algorithm);
  }
};

/// GELU forward functor. The "approximate" attribute selects oneDNN's tanh
/// approximation (true) or the erf-based formulation (false).
template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const mkldnn::algorithm gelu_algo =
        ctx.Attr<bool>("approximate") ? mkldnn::algorithm::eltwise_gelu_tanh
                                      : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_forward<T>(ctx, gelu_algo);
  }
};

/// GELU backward functor. It chooses the algorithm with the same
/// "approximate" attribute rule as the forward functor.
template <typename T>
struct GeluMKLDNNGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const mkldnn::algorithm gelu_algo =
        ctx.Attr<bool>("approximate") ? mkldnn::algorithm::eltwise_gelu_tanh
                                      : mkldnn::algorithm::eltwise_gelu_erf;
    eltwise_grad<T>(ctx, gelu_algo);
  }
};

208
template <typename T>
T
tensor-tang 已提交
209
using ReluMKLDNNFunctor =
210 211
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
212 213 214 215
template <typename T>
using Relu6MKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

216 217 218 219
template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

220 221 222 223
template <typename T>
using SigmoidMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_logistic>;

224
template <typename T>
T
tensor-tang 已提交
225
using TanhMKLDNNFunctor =
226 227 228
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
229
using SqrtMKLDNNFunctor =
230 231 232
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
233
using AbsMKLDNNFunctor =
234 235 236
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
T
tensor-tang 已提交
237
using ReluMKLDNNGradFunctor =
238 239
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

A
Adam 已提交
240 241 242 243
template <typename T>
using Relu6MKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_bounded_relu>;

244 245 246 247
template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

248 249 250 251
template <typename T>
using SigmoidMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_logistic>;

252
template <typename T>
T
tensor-tang 已提交
253
using TanhMKLDNNGradFunctor =
254 255 256
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
257
using SqrtMKLDNNGradFunctor =
258 259 260
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
261
using AbsMKLDNNGradFunctor =
262 263 264 265 266 267 268 269 270 271 272 273 274
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers the oneDNN forward kernel for `op_name` and the matching backward
// kernel for `op_name##_grad`. Both are registered on CPUPlace and
// instantiated for float.
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(op_name, fwd_functor, bwd_functor) \
  REGISTER_OP_KERNEL(op_name, MKLDNN, ::paddle::platform::CPUPlace,          \
                     ops::MKLDNNActivationKernel<ops::fwd_functor<float>>);  \
  REGISTER_OP_KERNEL(                                                        \
      op_name##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::bwd_functor<float>>);

// X-macro listing every activation with a oneDNN kernel in this file as
// (Paddle op name, forward functor, backward functor).
//
// The parameter is named `macro_` rather than `__macro`: identifiers that
// contain a double underscore are reserved for the implementation.
//
// leaky_relu reuses the relu functors. eltwise_forward/eltwise_grad forward
// the op's "alpha" attribute to oneDNN's eltwise_relu, which oneDNN documents
// as the negative slope.
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(macro_)                     \
  macro_(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);          \
  macro_(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor);       \
  macro_(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);    \
  macro_(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor);          \
  macro_(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);       \
  macro_(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradFunctor); \
  macro_(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);          \
  macro_(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);          \
  macro_(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);