/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/pooling.h"

namespace paddle {
namespace operators {

using dnnl::memory;
using dnnl::pooling_backward;
using dnnl::pooling_forward;
using dnnl::primitive;
using dnnl::reorder;
using dnnl::stream;
using framework::DataLayout;
using framework::Tensor;
using platform::to_void_cast;

template <typename T>
class PoolingMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T,
                                               dnnl::pooling_forward,
                                               dnnl::pooling_backward> {
 public:
  PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                       const dnnl::engine mkldnn_engine,
                       const Tensor* input,
                       Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T,
                                          dnnl::pooling_forward,
                                          dnnl::pooling_backward>(
            mkldnn_engine, ctx.GetPlace()) {
    const std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    const bool global_pooling = ctx.Attr<bool>("global_pooling");
    const std::string padding_algorithm =
        ctx.Attr<std::string>("padding_algorithm");

    // Only 2D pooling is supported now
    PADDLE_ENFORCE_EQ(
        ksize.size(),
        2,
        platform::errors::InvalidArgument(
            "The ksize must be 2D, i.e. 2D pooling, but received %dD.",
            ksize.size()));
    PADDLE_ENFORCE_EQ(
        pooling_type == "max" || pooling_type == "avg",
        true,
        platform::errors::InvalidArgument(
            "The pooling_type must be 'max' or 'avg', but received %s.",
            pooling_type));
    PADDLE_ENFORCE_EQ(
        input->dims().size(),
        4,
        platform::errors::InvalidArgument(
            "The input must be 4-D, i.e. NCHW, but received %d-D.",
            input->dims().size()));

    const auto input_dims = input->dims();
    framework::DDim data_dims =
        phi::slice_ddim(input_dims, 2, input_dims.size());

    if (global_pooling) {
      phi::funcs::UpdateKernelSize(&ksize, data_dims);
    }

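    // UpdatePadding turns padding_algorithm ("EXPLICIT"/"SAME"/"VALID") and
    // the global-pooling flag into explicit per-side padding values.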
    phi::funcs::UpdatePadding(&paddings,
                              global_pooling,
                              0,
                              padding_algorithm,
                              data_dims,
                              strides,
                              ksize);

    const auto is_test = ctx.Attr<bool>("is_test");
    const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
    const auto exclude_padding = ctx.Attr<bool>("exclusive");
    auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    const auto dt = framework::ToMKLDNNDataType(
        framework::TransToProtoVarType(input->dtype()));
    const auto src_tz = phi::vectorize(input->dims());
    const auto dst_tz = phi::vectorize(output->dims());
    const auto dst_md =
        platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);

    if (ceil_mode) {
      CorrectOutputSize(
          src_tz, dst_tz, ksize, paddings, strides, mkldnn_paddings[1]);
    }

    ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides);

    this->AcquireForwardPrimitiveDescriptor(
        is_test ? dnnl::prop_kind::forward_inference
                : dnnl::prop_kind::forward_training,
        pooling_type == "max"
            ? dnnl::algorithm::pooling_max
            : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
                               : dnnl::algorithm::pooling_avg_include_padding),
        input->mem_desc(),
        dst_md,
        strides,
        ksize,
        mkldnn_paddings[0],
        mkldnn_paddings[1]);
  }

  PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                       const dnnl::engine mkldnn_engine,
                       const Tensor* in_x,
                       const Tensor* out_grad,
                       Tensor* in_x_grad)
      : platform::MKLDNNHandlerNoCachingT<T,
                                          dnnl::pooling_forward,
                                          dnnl::pooling_backward>(
            mkldnn_engine, ctx.GetPlace()) {
    PADDLE_ENFORCE_EQ(
        ctx.Attr<bool>("is_test"),
        false,
        platform::errors::InvalidArgument(
            "is_test attribute should be set to False in training phase."));

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");

    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    auto in_x_dims = in_x->dims();
    framework::DDim data_dims = phi::slice_ddim(in_x_dims, 2, in_x_dims.size());

    if (global_pooling) {
      phi::funcs::UpdateKernelSize(&ksize, data_dims);
    }

    phi::funcs::UpdatePadding(&paddings,
                              global_pooling,
                              0,
                              padding_algorithm,
                              data_dims,
                              strides,
                              ksize);

    auto src_tz = phi::vectorize<int64_t>(in_x->dims());
    auto diff_src_tz = phi::vectorize<int64_t>(in_x_grad->dims());
    auto diff_dst_tz = phi::vectorize<int64_t>(out_grad->dims());

    const auto dt = framework::ToMKLDNNDataType(
        framework::TransToProtoVarType(in_x->dtype()));
    auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any);
    auto diff_src_md = dnnl::memory::desc(
        diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);

    auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
    const bool ceil_mode = ctx.Attr<bool>("ceil_mode");

    if (ceil_mode) {
      CorrectOutputSize(
          src_tz, diff_dst_tz, ksize, paddings, strides, mkldnn_paddings[1]);
    }
    ComputeAdaptivePoolParameters(ctx, diff_src_tz, &ksize, &strides);

    const auto exclude_padding = ctx.Attr<bool>("exclusive");

    // The backward primitive descriptor needs the forward one as a hint, so
    // the forward descriptor is created in the grad handler as well.
    this->AcquireForwardPrimitiveDescriptor(
        dnnl::prop_kind::forward_training,
        pooling_type == "max"
            ? dnnl::algorithm::pooling_max
            : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
                               : dnnl::algorithm::pooling_avg_include_padding),
        in_x->mem_desc(),
        dst_md,
        strides,
        ksize,
        mkldnn_paddings[0],
        mkldnn_paddings[1]);

    this->AcquireBackwardPrimitiveDescriptor(
        pooling_type == "max"
            ? dnnl::algorithm::pooling_max
            : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding
                               : dnnl::algorithm::pooling_avg_include_padding),
        diff_src_md,
        out_grad->mem_desc(),
        strides,
        ksize,
        mkldnn_paddings[0],
        mkldnn_paddings[1]);
  }

  std::shared_ptr<dnnl::memory> AcquireWorkspaceMemory(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& unique_name) {
    dnnl::memory::desc workspace_md = this->fwd_pd_->workspace_desc();
    // The pooling workspace has to be passed to the grad op, which may be
    // executed by a different thread; hence we use a key that does not
    // contain the TID.
    std::string workspace_key = platform::CreateKey(dev_ctx,
                                                    workspace_md.dims(),
                                                    workspace_md.data_type(),
                                                    unique_name,
                                                    "@wrk");
    auto mem_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(workspace_key));
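    // Double-checked locking: if the blob is missing, take the lock, re-check,
    // and only then create and register the workspace memory.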
    if (mem_p == nullptr) {
      static std::mutex acquire_barrier;
      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
          acquire_barrier);
      mem_p = std::static_pointer_cast<dnnl::memory>(
          dev_ctx.GetBlob(workspace_key));
      if (mem_p == nullptr) {
        mem_p = std::make_shared<dnnl::memory>(workspace_md, this->engine_);
        dev_ctx.SetBlob(workspace_key, mem_p);
      }
    }
    return mem_p;
  }

  static void ComputeAdaptivePoolParameters(
      const paddle::framework::ExecutionContext& ctx,
      const std::vector<int64_t>& src_tz,
      std::vector<int64_t>* ksize,
      std::vector<int64_t>* strides) {
    if (ctx.Attr<bool>("adaptive")) {
      // https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling
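      // For adaptive pooling, ksize carries the requested output size (OH, OW);
      // following the RFC above it is converted to an approximate fixed kernel
      // and stride:
      //   stride = floor(2*I/O) - floor(I/O), kernel = ceil(2*I/O) - floor(I/O)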
      auto IH = static_cast<double>(src_tz[src_tz.size() - 2]);
      auto IW = static_cast<double>(src_tz[src_tz.size() - 1]);
      auto OH = static_cast<double>(ksize->at(0));
      auto OW = static_cast<double>(ksize->at(1));

      strides->at(0) =
          static_cast<int64_t>(floor((IH * 2.0) / OH) - floor(IH / OH));
      strides->at(1) =
          static_cast<int64_t>(floor((IW * 2.0) / OW) - floor(IW / OW));
      ksize->at(0) =
          static_cast<int64_t>(ceil((IH * 2.0) / OH) - floor(IH / OH));
      ksize->at(1) =
          static_cast<int64_t>(ceil((IW * 2.0) / OW) - floor(IW / OW));
    }
  }

 private:
  static inline int ComputeCeiledOutput(int input_size,
                                        int kernel_size,
                                        int padding,
                                        int stride) {
    return (input_size - kernel_size + 2 * padding) / stride + 1;
  }

  // With ceil_mode the framework may expect one more output row/column than
  // the floor-based formula yields; in that case the right/bottom padding is
  // enlarged so that oneDNN produces the expected output size.
  static inline void CorrectOutputSize(
      const std::vector<int64_t>& src_tz,
      const std::vector<int64_t>& dst_tz,
      const std::vector<int64_t>& kernel_size,
      const std::vector<int64_t>& paddings,
      const std::vector<int64_t>& strides,
      std::vector<int64_t>& right_bot_padding) {  // NOLINT
    for (size_t i = 0; i < right_bot_padding.size(); i++) {
      int desired_size = ComputeCeiledOutput(
          src_tz[i + 2], kernel_size[i], paddings[i], strides[i]);
      if (desired_size != dst_tz[i + 2]) {
        right_bot_padding[i] += strides[i] - 1;
      }
    }
  }
};

template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL Pool must use CPUPlace"));
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");

    PoolingMKLDNNHandler<T> handler(ctx, dev_ctx.GetEngine(), input, output);

    auto src_memory = handler.AcquireSrcMemory(input);
    auto dst_memory = handler.AcquireDstMemory(output);

    auto pool_p = handler.AcquireForwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
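    // In training, max pooling also produces a workspace (the selected input
    // positions) that the grad kernel later uses to route gradients.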
    if ((ctx.Attr<bool>("is_test") == false) &&
        (ctx.Attr<std::string>("pooling_type") == "max")) {
      // Training
      auto workspace_memory =
          handler.AcquireWorkspaceMemory(dev_ctx, ctx.OutputName("Out"));
      pool_p->execute(astream,
                      {{DNNL_ARG_SRC, *src_memory},
                       {DNNL_ARG_DST, *dst_memory},
                       {DNNL_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Inference
      pool_p->execute(
          astream, {{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_DST, *dst_memory}});
    }
    astream.wait();

    output->set_mem_desc(dst_memory->get_desc());
  }
};

template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      paddle::platform::errors::PreconditionNotMet(
                          "Operator DNNL PoolGrad must use CPUPlace"));
    const Tensor* in_x = ctx.Input<Tensor>("X");
    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    PoolingMKLDNNHandler<T> handler(
        ctx, dev_ctx.GetEngine(), in_x, out_grad, in_x_grad);

    auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
    auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);

    auto pool_bwd_p = handler.AcquireBackwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    if (ctx.Attr<std::string>("pooling_type") == "max") {
      // Max pooling needs the workspace saved by the forward pass; it is
      // looked up via the name of the "Out" variable shared by both kernels.
      auto workspace_memory =
          handler.AcquireWorkspaceMemory(dev_ctx, ctx.InputName("Out"));
      pool_bwd_p->execute(astream,
                          {{DNNL_ARG_DIFF_SRC, *diff_src_memory},
                           {DNNL_ARG_DIFF_DST, *diff_dst_memory},
                           {DNNL_ARG_WORKSPACE, *workspace_memory}});
    } else {
      // Average Pooling
      pool_bwd_p->execute(astream,
                          {{DNNL_ARG_DIFF_SRC, *diff_src_memory},
                           {DNNL_ARG_DIFF_DST, *diff_dst_memory}});
    }
    astream.wait();

    in_x_grad->set_mem_desc(diff_src_memory->get_desc());
  }  // Compute()
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(pool2d,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNOpKernel<float>,
                   ops::PoolMKLDNNOpKernel<int8_t>,
                   ops::PoolMKLDNNOpKernel<uint8_t>,
                   ops::PoolMKLDNNOpKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(pool2d_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::PoolMKLDNNGradOpKernel<float>,
                   ops::PoolMKLDNNGradOpKernel<paddle::platform::bfloat16>);