// activation_mkldnn_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

21 22 23 24 25 26 27 28
using framework::DataLayout;
using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
29

30 31 32 33 34 35
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *x = ctx.Input<Tensor>("X");
36 37
    PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for X tensor");
A
Adam 已提交
38
    PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
39
                      "Wrong format set for X tensor");
40 41 42 43 44

    Functor functor;
    functor(ctx);
  }
};
K
Krzysztof Binias 已提交
45

46 47 48 49 50 51
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
52 53
    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
                      "Wrong layout set for Input OutGrad tensor");
A
Adam 已提交
54
    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
55
                      "Wrong format set for Input OutGrad tensor");
56

57 58
    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"), false,
59 60
        "is_test attribute should be set to False in training phase.");

61 62 63 64 65 66 67
    Functor functor;
    functor(ctx);
  }
};

template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
A
Adam 已提交
68
                     mkldnn::algorithm algorithm) {
69 70 71 72
  PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                 "It must use CPUPlace.");
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();

73 74
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");
75

76 77 78 79 80 81 82
  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  }
A
Adam 已提交
83

Y
Yihua Xu 已提交
84 85 86 87
  PADDLE_ENFORCE(
      x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
      "Input dim must be with 2, 3 or 4");

A
Adam 已提交
88
  auto src_tz = framework::vectorize<int64_t>(x->dims());
89

90
  auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
91

92 93
  bool is_test = ctx.Attr<bool>("is_test");

94 95
  platform::ActivationMKLDNNHandler<T> handler(
      src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx,
H
hong 已提交
96
      ctx.GetPlace(), ctx.InputName("X"));
97

98 99
  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = handler.AcquireDstMemory(y);
A
Adam 已提交
100
  auto activation_p = handler.AcquireForwardPrimitive();
101

A
Adam 已提交
102 103 104 105
  mkldnn::stream astream(dev_ctx.GetEngine());
  activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                  {MKLDNN_ARG_TO, *dst_memory_p}});
  astream.wait();
106

107
  y->set_layout(DataLayout::kMKLDNN);
108
  y->set_format(GetMKLDNNFormat(*dst_memory_p));
109 110
}

111 112
template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
A
Adam 已提交
113
                  mkldnn::algorithm algorithm) {
114 115
  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();

116
  const auto *x = ctx.Input<Tensor>("X");
117 118
  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
119

120 121 122 123 124 125 126
  T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
  T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;

  // paddle uses beta but mkldnn uses alpha for swish
  if (algorithm == mkldnn::algorithm::eltwise_swish) {
    std::swap(alpha, beta);
  }
A
Adam 已提交
127

A
Adam 已提交
128
  auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
K
Krzysztof Binias 已提交
129

130 131
  // diff_dst and src dims should be the same
  auto src_format =
132
      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
133

134
  auto diff_y_format =
135
      diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
136

137 138
  platform::ActivationMKLDNNHandler<T> handler(
      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
H
hong 已提交
139
      ctx.GetPlace(), ctx.InputName("X"));
140

141 142 143
  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
A
Adam 已提交
144 145 146 147 148 149 150 151
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  mkldnn::stream astream(dev_ctx.GetEngine());
  activation_backward_p->execute(astream,
                                 {{MKLDNN_ARG_SRC, *src_memory_p},
                                  {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();
152

153
  diff_x->set_layout(DataLayout::kMKLDNN);
154
  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
155 156 157 158
}

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
159
  void operator()(const framework::ExecutionContext &ctx) const {
160 161 162 163 164 165
    eltwise_forward<T>(ctx, algorithm);
  }
};

template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
166
  void operator()(const framework::ExecutionContext &ctx) const {
167 168 169 170 171
    eltwise_grad<T>(ctx, algorithm);
  }
};

template <typename T>
T
tensor-tang 已提交
172
using ReluMKLDNNFunctor =
173 174
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;

175 176 177 178
template <typename T>
using SwishMKLDNNFunctor =
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_swish>;

179
template <typename T>
T
tensor-tang 已提交
180
using TanhMKLDNNFunctor =
181 182 183
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
184
using SqrtMKLDNNFunctor =
185 186 187
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
188
using AbsMKLDNNFunctor =
189 190 191
    MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;

template <typename T>
T
tensor-tang 已提交
192
using ReluMKLDNNGradFunctor =
193 194
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;

195 196 197 198
template <typename T>
using SwishMKLDNNGradFunctor =
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_swish>;

199
template <typename T>
T
tensor-tang 已提交
200
using TanhMKLDNNGradFunctor =
201 202 203
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;

template <typename T>
T
tensor-tang 已提交
204
using SqrtMKLDNNGradFunctor =
205 206 207
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;

template <typename T>
T
tensor-tang 已提交
208
using AbsMKLDNNGradFunctor =
209 210 211 212 213 214 215 216 217 218 219 220 221
    MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Registers two float kernels under the MKLDNN library on CPUPlace:
//   - `op_type`       -> MKLDNNActivationKernel<fwd_functor<float>>
//   - `op_type##_grad` -> MKLDNNActivationGradKernel<bwd_functor<float>>
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(op_type, fwd_functor, bwd_functor) \
  REGISTER_OP_KERNEL(op_type, MKLDNN, ::paddle::platform::CPUPlace,          \
                     ops::MKLDNNActivationKernel<ops::fwd_functor<float>>);  \
  REGISTER_OP_KERNEL(                                                        \
      op_type##_grad, MKLDNN, ::paddle::platform::CPUPlace,                  \
      ops::MKLDNNActivationGradKernel<ops::bwd_functor<float>>);

// Applies `macro_(op_type, ForwardFunctor, GradFunctor)` once for every
// activation op that has an MKL-DNN kernel in this file.
// leaky_relu reuses the relu functors. eltwise_forward/eltwise_grad forward an
// "alpha" attribute, when the op defines one, to the MKL-DNN relu primitive
// (presumably leaky_relu's negative slope — confirm against the op definition).
// The parameter is named `macro_` because identifiers containing a double
// underscore are reserved for the implementation.
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(macro_)                  \
  macro_(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);       \
  macro_(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
  macro_(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);    \
  macro_(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);       \
  macro_(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradFunctor);       \
  macro_(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);