/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

namespace {
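// Builds a cache key from the operand dimensions and the eltwise algorithm.
// This key (extended below with the memory format and the output variable
// name) is used to cache MKL-DNN memories and primitives in the device
// context so they can be reused across kernel invocations.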
std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dim2str(operand_dims) + std::to_string(algorithm);
}
}  // namespace

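// Forward kernel wrapper: checks that the input tensor already carries an
// MKL-DNN layout with a defined format, then runs the activation functor.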
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                       x->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input x tensor");

    Functor functor;
    functor(ctx);
  }
};

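// Backward kernel wrapper: performs the same layout check on the OutGrad
// tensor and rejects is_test == true, since the backward path relies on
// data cached during a training-mode forward pass.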
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
                       diff_y->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input OutGrad tensor");

    PADDLE_ENFORCE(
        !ctx.Attr<bool>("is_test"),
        "is_test attribute should be set to False in training phase.");

    Functor functor;
    functor(ctx);
  }
};

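// Forward computation of an MKL-DNN eltwise activation. On the first call
// for a given shape/format/algorithm the MKL-DNN memories and the forward
// primitive are created and cached in the device context; later calls only
// refresh the data handles and re-submit the cached primitive.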
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm, const T alpha = 0,
                     const T beta = 0) {
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  const T *x_data = x->data<T>();
  T *y_data = y->mutable_data<T>(ctx.GetPlace());

  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
      "Input dim must be with 2, 3 or 4");

  std::vector<int> src_tz = framework::vectorize2int(x->dims());

  auto src_format = x->format();

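  // Keys under which the source data/layout, memories, the forward primitive
  // and its primitive descriptor are cached. The input format is folded into
  // the key so each distinct layout gets its own cached primitives.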
  const std::string key = gethash(src_tz, algorithm);
  const std::string key_src_data =
      key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
  const std::string key_src_layout =
      key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
  const std::string key_with_layout = key + std::to_string(src_format);
  const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
  const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
  const std::string key_fwd = key_with_layout + "@eltwise_fwd";
  const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";

  bool is_test = ctx.Attr<bool>("is_test");

  // save input data and layout to be referred in backward path
  auto p_src_data = std::make_shared<const T *>(x_data);
  auto p_src_layout = std::make_shared<memory::format>(src_format);
  if (!is_test) {
    dev_ctx.SetBlob(key_src_data, p_src_data);
    dev_ctx.SetBlob(key_src_layout, p_src_layout);
  }

  auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
      dev_ctx.GetBlob(key_fwd));

  std::shared_ptr<memory> dst_memory;

  if (p_fwd == nullptr) {
    // create mkldnn memory for input X
    auto src_memory = std::shared_ptr<memory>(
        new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
    // save src_memory to be referred in backward path
    dev_ctx.SetBlob(key_src_mem, src_memory);

    // create primitive descriptor for activation forward and save it
    auto mkldnn_forward_prop_kind = is_test
                                        ? mkldnn::prop_kind::forward_inference
                                        : mkldnn::prop_kind::forward_training;
    auto forward_desc = mkldnn::eltwise_forward::desc(
        mkldnn_forward_prop_kind, algorithm,
        src_memory->get_primitive_desc().desc(), alpha, beta);
    auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
        forward_desc, mkldnn_engine);

    // save prim desc into global device context to be referred in backward path
    if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);

    // create mkldnn memory for output y
    dst_memory =
        std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);

    dev_ctx.SetBlob(key_dst_mem, dst_memory);

    // create activation primitive
    p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
                                                      *dst_memory);
    dev_ctx.SetBlob(key_fwd, p_fwd);
  } else {
    // primitives already exist
    auto src_memory =
        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
    PADDLE_ENFORCE(src_memory != nullptr,
                   "Fail to find eltwise src_memory in device context.");
    dst_memory =
        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
    PADDLE_ENFORCE(dst_memory != nullptr,
                   "Fail to find eltwise dst_memory in device context.");

    src_memory->set_data_handle(platform::to_void_cast(x_data));
    dst_memory->set_data_handle(y_data);
  }

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*p_fwd);
  stream(stream::kind::eager).submit(pipeline).wait();

  y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
}

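// Backward computation of an MKL-DNN eltwise activation. It looks up the
// source data, layout and forward primitive descriptor cached by
// eltwise_forward and uses them to build (or reuse) the backward primitive.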
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm, const T alpha = 0,
                  const T beta = 0) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  const T *diff_y_data = diff_y->data<T>();
  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());

  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());

  const std::string key = gethash(diff_dst_tz, algorithm);
  const std::string key_src_data =
      key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
  const std::string key_src_layout =
      key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout";
  const auto p_src_layout =
      std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
  const std::string key_src_mem =
      key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
  const std::string key_fwd_pd =
      key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
  const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
                                       "-" + std::to_string(diff_y->format());
  const std::string key_diff_src_mem =
      key_with_layouts + "@eltwise_diff_src_mem";
  const std::string key_diff_dst_mem =
      key_with_layouts + "@eltwise_diff_dst_mem";
  const std::string key_grad = key_with_layouts + "@eltwise_grad";

  const auto p_src_data =
      std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));

  auto src_memory =
      std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
  PADDLE_ENFORCE(src_memory != nullptr,
                 "Fail to find src_memory in device context");
  src_memory->set_data_handle(*p_src_data);

  std::shared_ptr<memory> diff_src_memory;

  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
      dev_ctx.GetBlob(key_grad));

  if (p_grad == nullptr) {
    // create mkldnn memory for input diff_y
    auto diff_dst_memory = std::shared_ptr<memory>(
        new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
    dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);

    // retrieve eltwise primitive desc from device context
    auto forward_pd =
        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
            dev_ctx.GetBlob(key_fwd_pd));
    PADDLE_ENFORCE(forward_pd != nullptr,
                   "Fail to find eltwise_fwd_pd in device context");

    // create primitive descriptor for activation backward
    auto backward_desc = mkldnn::eltwise_backward::desc(
        algorithm, diff_dst_memory->get_primitive_desc().desc(),
        src_memory->get_primitive_desc().desc(), alpha, beta);
    auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
        backward_desc, mkldnn_engine, *forward_pd);

    // create mkldnn memory for output diff_src
    diff_src_memory = std::make_shared<memory>(
        backward_pd.diff_src_primitive_desc(), diff_x_data);
    dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);

    // create activation backward primitive
    p_grad = std::make_shared<mkldnn::eltwise_backward>(
        backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
    dev_ctx.SetBlob(key_grad, p_grad);
  } else {
    // primitives already exist
    diff_src_memory = std::static_pointer_cast<mkldnn::memory>(
        dev_ctx.GetBlob(key_diff_src_mem));
    auto diff_dst_memory = std::static_pointer_cast<mkldnn::memory>(
        dev_ctx.GetBlob(key_diff_dst_mem));

    diff_src_memory->set_data_handle(
        platform::to_void_reinterpret_cast(diff_x_data));
    diff_dst_memory->set_data_handle(
        platform::to_void_reinterpret_cast(diff_y_data));
  }

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*p_grad);
  stream(stream::kind::eager).submit(pipeline).wait();

  diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
}

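// Functor adapters that bind a concrete MKL-DNN eltwise algorithm to the
// generic forward/backward kernels above.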
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

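// Registers the forward and gradient MKL-DNN kernels (float only) for a
// given activation type on CPUPlace.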
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)            \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor); \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

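// For reference, the invocation below expands to kernel registrations such as
//   REGISTER_OP_KERNEL(relu, MKLDNN, ::paddle::platform::CPUPlace,
//                      ops::MKLDNNActivationKernel<ops::ReluMKLDNNFunctor<float>>);
//   REGISTER_OP_KERNEL(relu_grad, MKLDNN, ::paddle::platform::CPUPlace,
//                      ops::MKLDNNActivationGradKernel<ops::ReluMKLDNNGradFunctor<float>>);
// and likewise for tanh, sqrt and abs.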
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);