/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;

namespace {
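// Helper that builds a string key from the operand dimensions and the
// eltwise algorithm.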
std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dim2str(operand_dims) + std::to_string(algorithm);
}
}  // namespace

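// Forward kernel: verifies that the input tensor carries an MKL-DNN
// layout/format, then delegates the computation to the algorithm-specific
// functor.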
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
                       x->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input x tensor");

    Functor functor;
    functor(ctx);
  }
};

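// Backward kernel: verifies the MKL-DNN layout/format of the output gradient
// and that is_test is false, then delegates to the gradient functor.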
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
    PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
                       diff_y->format() != memory::format::format_undef,
                   "Wrong layout/format set for Input OutGrad tensor");

    PADDLE_ENFORCE(
        !ctx.Attr<bool>("is_test"),
        "is_test attribute should be set to False in training phase.");

    Functor functor;
    functor(ctx);
  }
};

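// Computes y = f(x) with an MKL-DNN eltwise forward primitive; memory objects
// and primitives are cached and reused via ActivationMKLDNNHandler.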
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
                     mkldnn::algorithm algorithm) {
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");

  const T *x_data = x->data<T>();

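  // Optional eltwise parameters (e.g. alpha is the negative slope for
  // leaky_relu); they default to 0 when the op has no such attributes.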
  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
      "Input dim must be with 2, 3 or 4");

  std::vector<int> src_tz = framework::vectorize2int(x->dims());

  auto src_format =
      src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();

  bool is_test = ctx.Attr<bool>("is_test");

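  // Key identifying the cached MKL-DNN primitives for this input and
  // configuration in the device context.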
  std::string key = platform::ActivationMKLDNNHandler::GetHash(
      src_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));

  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
                                    src_format);

  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
      is_test ? mkldnn::prop_kind::forward_inference
              : mkldnn::prop_kind::forward_training,
      algorithm, md, alpha, beta);

  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));

  auto dst_memory_p =
      handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  y->set_layout(DataLayout::kMKLDNN);
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
}

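// Computes the input gradient dx with an MKL-DNN eltwise backward primitive,
// using the forward input x and the output gradient dy.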
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
                  mkldnn::algorithm algorithm) {
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
  const auto &mkldnn_engine = dev_ctx.GetEngine();

  const auto *x = ctx.Input<Tensor>("X");
  const T *x_data = x->data<T>();

  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));

  const T *diff_y_data = diff_y->data<T>();
  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());

  const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());

  // diff_dst and src dims should be the same
  auto src_format =
      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();

  auto diff_y_format =
      diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();

  auto diff_dst_md = platform::MKLDNNMemDesc(
      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);

  std::string key = platform::ActivationMKLDNNHandler::GetHash(
      diff_dst_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));

  const std::string key_src_data = key + "@eltwise_fwd_src_data";

  auto src_md = platform::MKLDNNMemDesc(
      diff_dst_tz, platform::MKLDNNGetDataType<T>(), src_format);

  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

  auto src_memory_p = handler.AcquireSrcMemory(src_md, to_void_cast<T>(x_data));

  auto diff_dst_memory_p =
      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));

  auto activation_backward_pd =
      handler.AcquireActivationBackwardPrimitiveDescriptor(
          algorithm, diff_dst_md, src_memory_p->get_primitive_desc().desc(),
          alpha, beta);

  auto diff_src_memory_p =
      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);

  auto activation_backward_p = handler.AcquireActivationBackward(
      diff_src_memory_p, diff_dst_memory_p, src_memory_p);

  // push primitive to stream and wait until it's executed
  std::vector<primitive> pipeline;
  pipeline.push_back(*activation_backward_p);
  stream(stream::kind::eager).submit(pipeline).wait();

  diff_x->set_layout(DataLayout::kMKLDNN);
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

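// Binds a specific MKL-DNN eltwise algorithm to the forward computation.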
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_forward<T>(ctx, algorithm);
  }
};

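// Binds a specific MKL-DNN eltwise algorithm to the backward computation.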
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T>
using ReluMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
using ReluMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

template <typename T>
using TanhMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
using SqrtMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
using AbsMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

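// Registers the forward (act_type) and backward (act_type_grad) MKL-DNN
// kernels of an activation operator for float on CPUPlace.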
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace,       \
                     ops::MKLDNNActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                      \
      act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,               \
      ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);

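// Activation operators that get MKL-DNN kernels; leaky_relu reuses the
// eltwise_relu primitive with alpha acting as the negative slope.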
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro)                  \
  __macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
  __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
  __macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
  __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);