/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

namespace {
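// Builds a cache key string from the operand dimensions and the eltwise
// algorithm.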
std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dim2str(operand_dims) + std::to_string(algorithm);
}
}  // namespace

template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                       x->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input x tensor");

    Functor functor;
    functor(ctx);
  }
};

template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
                       diff_y->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input OutGrad tensor");

    PADDLE_ENFORCE(
        !ctx.Attr<bool>("is_test"),
        "is_test attribute should be set to False in training phase.");

    Functor functor;
    functor(ctx);
  }
};

template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm, const T alpha = 0,
                     const T beta = 0) {
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  const T *x_data = x->data<T>();
  T *y_data = y->mutable_data<T>(ctx.GetPlace());

  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
      "Input dim must be with 2, 3 or 4");

  std::vector<int> src_tz = framework::vectorize2int(x->dims());

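  // Use the plain nc format for 2-D inputs; otherwise keep the format
  // recorded on the input tensor.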
  auto src_format =
      src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();

  bool is_test = ctx.Attr<bool>("is_test");

  // TODO(jczaja): When adding leaky-relu, swish and elu, make sure to extend
  // the key with alpha and beta
  std::string key = platform::MKLDNNHandler::GetHash(
      src_tz, std::to_string(algorithm) + ctx.op().Output("Out"));

  // TODO(jczaja): Make it Thread safe
  // save input data and layout to be referred in backward path
  const std::string key_src_data = key + "@eltwise_fwd_src_data";
  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
  // Just in case some int8 models are run interchangeably
  // with float models, the format may be different
  key += std::to_string(src_format);
  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
  auto p_src_data = std::make_shared<const T *>(x_data);
  auto p_src_layout = std::make_shared<memory::format>(src_format);
  if (!is_test) {
    dev_ctx.SetBlob(key_src_data, p_src_data);
    dev_ctx.SetBlob(key_src_layout, p_src_layout);
  }

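  // The handler creates (or retrieves from the device-context cache, keyed by
  // 'key') the MKL-DNN memories and the activation primitive.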
  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
                                    src_format);

  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
      is_test ? mkldnn::prop_kind::forward_inference
              : mkldnn::prop_kind::forward_training,
      algorithm, md, alpha, beta);

  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
  // jczaja: Workaround, src_memory_p is needed in BWD so it has
  // to be accessible under a key not dependent on TID
  if (!is_test) {
    dev_ctx.SetBlob(key_src_mem, src_memory_p);
  }

  auto dst_memory_p =
      handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm, const T alpha = 0,
                  const T beta = 0) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  const T *diff_y_data = diff_y->data<T>();
  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());

  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());

  auto diff_y_format =
      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();

  auto diff_dst_md = platform::MKLDNNMemDesc(
      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);

  std::string key = platform::MKLDNNHandler::GetHash(
      diff_dst_tz, std::to_string(algorithm) + ctx.op().Input("Out"));

  const std::string key_src_data = key + "@eltwise_fwd_src_data";
  const std::string key_src_layout = key + "@eltwise_fwd_src_layout";

  // Get Data from FWD op
  const auto p_src_layout =
      std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
  const auto p_src_data =
      std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
  key += std::to_string(*p_src_layout);
  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
  auto src_memory =
      std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
  PADDLE_ENFORCE(src_memory != nullptr,
                 "Fail to find src_memory in device context");
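  // Re-point the cached source memory at the input data saved by the forward
  // pass.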
  src_memory->set_data_handle(*p_src_data);

  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto diff_dst_memory_p =
      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));

  auto activation_backward_pd =
      handler.AcquireActivationBackwardPrimitiveDescriptor(
          algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(),
          alpha, beta);

  auto diff_src_memory_p =
      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);

  auto activation_backward_p = handler.AcquireActivationBackward(
      diff_src_memory_p, diff_dst_memory_p, src_memory);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_backward_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  diff_x->set_layout(DataLayout::kMKLDNN);
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

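// Expands the given macro once for every activation that has an MKL-DNN
// kernel.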
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)            \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);