/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

namespace {
std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dim2str(operand_dims) + std::to_string(algorithm);
}
}  // namespace

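// Forward activation kernel: verifies that the input tensor carries an
// MKL-DNN layout with a defined format, then dispatches to the functor that
// runs the actual eltwise primitive.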
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                       x->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input x tensor");

    Functor functor;
    functor(ctx);
  }
};

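// Backward activation kernel: performs the same layout/format check on the
// Out gradient and requires is_test to be false before running the functor.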
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
                       diff_y->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input OutGrad tensor");

    PADDLE_ENFORCE(
        !ctx.Attr<bool>("is_test"),
        "is_test attribute should be set to False in training phase.");

    Functor functor;
    functor(ctx);
  }
};

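// Executes the forward eltwise primitive for the given MKL-DNN algorithm.
// In training mode it also caches the source data, layout and memory in the
// device context so that the backward pass can reuse them.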
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  const T *x_data = x->data<T>();
  T *y_data = y->mutable_data<T>(ctx.GetPlace());

  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
      "Input dim must be with 2, 3 or 4");

  std::vector<int> src_tz = framework::vectorize2int(x->dims());

  auto src_format =
      src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();

  bool is_test = ctx.Attr<bool>("is_test");

  std::string key = platform::MKLDNNHandler::GetHash(
      src_tz, std::to_string(algorithm) + std::to_string(alpha) +
                  std::to_string(beta) + ctx.op().Input("X"));

  // TODO(jczaja): Make it Thread safe
  // Save input data and layout to be referred to in the backward path
  const std::string key_src_data = key + "@eltwise_fwd_src_data";
  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
  // Just in case some int8 models are run interchangeably
  // with float models, the format may be different
  key += std::to_string(src_format);
  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
  auto p_src_data = std::make_shared<const T *>(x_data);
  auto p_src_layout = std::make_shared<memory::format>(src_format);
  if (!is_test) {
    dev_ctx.SetBlob(key_src_data, p_src_data);
    dev_ctx.SetBlob(key_src_layout, p_src_layout);
  }

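  // Acquire the primitive descriptor and memories through the reuse handler,
  // keyed on the hash computed above.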
  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
                                    src_format);

  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
      is_test ? mkldnn::prop_kind::forward_inference
              : mkldnn::prop_kind::forward_training,
      algorithm, md, alpha, beta);

  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
  // jczaja: Workaround, src_memory_p is needed in BWD so it has
  // to be accessible under a key not dependent on TID
  if (!is_test) {
    dev_ctx.SetBlob(key_src_mem, src_memory_p);
  }

  auto dst_memory_p =
      handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

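// Executes the backward eltwise primitive: recovers the source memory cached
// by eltwise_forward from the device context and computes the gradient of X
// from the gradient of Out.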
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  const T *diff_y_data = diff_y->data<T>();
  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());

  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());

  auto diff_y_format =
      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();

  auto diff_dst_md = platform::MKLDNNMemDesc(
      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);

  std::string key = platform::MKLDNNHandler::GetHash(
      diff_dst_tz, std::to_string(algorithm) + std::to_string(alpha) +
                       std::to_string(beta) + ctx.op().Input("X"));

  const std::string key_src_data = key + "@eltwise_fwd_src_data";
  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";

  // Get Data from FWD op
  const auto p_src_layout =
      std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
  const auto p_src_data =
      std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
  key += std::to_string(*p_src_layout);
  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
  auto src_memory =
      std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
  PADDLE_ENFORCE(src_memory != nullptr,
                 "Fail to find src_memory in device context");
  src_memory->set_data_handle(*p_src_data);

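  // Acquire the backward primitive and its memories through the reuse
  // handler; the key mirrors the one built in the forward pass.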
  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto diff_dst_memory_p =
      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));

  auto activation_backward_pd =
      handler.AcquireActivationBackwardPrimitiveDescriptor(
          algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(),
          alpha, beta);

  auto diff_src_memory_p =
      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);

  auto activation_backward_p = handler.AcquireActivationBackward(
      diff_src_memory_p, diff_dst_memory_p, src_memory);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_backward_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  diff_x->set_layout(DataLayout::kMKLDNN);
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

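// Functor wrappers binding a concrete MKL-DNN eltwise algorithm to the
// generic forward/backward kernels above.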
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

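// Registers the MKL-DNN forward and backward kernels (float only) for a
// given activation operator.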
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

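// Note: leaky_relu reuses the plain relu functor; its alpha attribute is
// passed to the eltwise_relu primitive as the negative slope.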
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                  \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);