/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "boost/optional.hpp"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::DataLayout;

inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
  auto weights_tz = phi::vectorize(filter->dims());
  int g = std::max(groups, 1);
  int g_dim = (g > 1) ? 1 : 0;
  platform::GetGroupConvWeightsTz(weights_tz, g);
  // gIOHW -> gOIHW || IOHW -> OIHW
  std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
  return weights_tz;
}
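// A worked example of the swap above (hypothetical shapes, assuming
// GetGroupConvWeightsTz prepends the group count and splits the leading axis
// per group): a conv2d_transpose filter of dims [8, 16, 3, 3] (IOHW) with
// groups == 2 becomes [2, 4, 16, 3, 3] (gIOHW), and the swap yields
// [2, 16, 4, 3, 3] (gOIHW), the order oneDNN's deconvolution expects.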

template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward> {
 public:
  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
                              const dnnl::engine mkldnn_engine,
                              const Tensor* input,
                              const Tensor* filter,
                              const Tensor* bias,
                              Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
            mkldnn_engine, ctx.GetPlace()),
        is_test_(ctx.Attr<bool>("is_test")) {
    PADDLE_ENFORCE_EQ(is_test_,
                      true,
                      platform::errors::InvalidArgument(
                          "ConvTransposeMKLDNN works only for inference. "
                          "The attribute \'is_test\' value should be set to "
                          "True, but got is_test=False."));

    PADDLE_ENFORCE_EQ(
        input->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "Got wrong layout = %d for Input tensor.", input->layout()));
    PADDLE_ENFORCE_NE(input->format(),
                      MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Input tensor. The input "
                          "format is undefined."));

    PADDLE_ENFORCE_EQ(
        filter->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The filter tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN,
            filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(),
                      MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Filter tensor. The filter "
                          "format is undefined."));

    PADDLE_ENFORCE_EQ(
        input->dims().size(),
        4,
        platform::errors::InvalidArgument("Input must have 4 dimensions, "
                                          "i.e. NCHW, but got dimension = %d.",
                                          input->dims().size()));
    PADDLE_ENFORCE_EQ(
        filter->dims().size(),
        4,
        platform::errors::InvalidArgument("Filter must have 4 dimensions, "
                                          "i.e. OIHW, but got dimension = %d.",
                                          filter->dims().size()));

    if (bias) {
      PADDLE_ENFORCE_EQ(
          bias->layout(),
          DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The bias tensor's layout should be %d, but got %d.",
              DataLayout::kMKLDNN,
              bias->layout()));
      PADDLE_ENFORCE_NE(bias->format(),
                        MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Got wrong format for Bias tensor."));

      PADDLE_ENFORCE_EQ(
          bias->dims().size(),
          1,
          platform::errors::InvalidArgument("Bias must only have 1 dimension, "
                                            "i.e. X, but got dimension = %d.",
                                            bias->dims().size()));
    }

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    dnnl::memory::dims strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    dnnl::memory::dims paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    dnnl::memory::dims dilations(begin(dilations_temp), end(dilations_temp));

    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(
        strides.size(),
        2,
        platform::errors::Unimplemented(
            "Currently only 2D oneDNN convolution transpose is supported."));

    const auto& input_dims = input->dims();
    const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
    const auto& filter_dims = filter->dims();
    const auto filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());

    const auto ksize = phi::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(
        &paddings, &dilations, padding_algorithm, data_dims, strides, ksize);

    std::transform(
        dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
          return i - 1;
        });
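    // NB: the transform above shifts Paddle's dilation convention
    // (1 == no dilation) to oneDNN's zero-based one (0 == no dilation).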

    const auto src_tz = phi::vectorize(input->dims());
    const auto weights_tz = GetWeightsTz(filter, groups);
    const auto dst_tz = phi::vectorize(output->dims());
    const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    const auto chosen_memory_format = MKLDNNMemoryFormat::any;

    auto data_type = dnnl::memory::data_type::f32;
    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
        std::is_same<T_out, platform::bfloat16>::value)
      data_type = dnnl::memory::data_type::bf16;

    const auto src_md =
        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
    const auto weights_md =
        platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
    const auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
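    // Note: the src/weights descriptors follow data_type (possibly bf16),
    // while dst follows T_out; this is what lets force_fp32_output produce
    // an fp32 result from a bf16 primitive (see Compute() below).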

    const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
    auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
                                  : dnnl::prop_kind::forward_training;
    if (bias) {
      std::vector<int64_t> bias_tz = phi::vectorize(bias->dims());
      const auto bias_md =
          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          bias_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    } else {
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    }
  }

  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
    dnnl::primitive_attr conv_attr;
    dnnl::post_ops post_operations;

    const std::string fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    const float fuse_beta = ctx.Attr<float>("fuse_beta");

    // Fused activations are executed through the post-ops feature: configure
    // an eltwise post-op that runs right after the deconvolution itself.
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "swish") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_swish, fuse_alpha, fuse_beta);
    }
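    // In the mapping above, fuse_alpha doubles as bounded_relu's upper bound
    // (6.0 for relu6) and as swish's sigmoid scale factor; fuse_beta is
    // unused by those two eltwise algorithms.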
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
      const framework::Tensor* input) {
    const T* input_data = input->data<T>();
    auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
                                               platform::MKLDNNGetDataType<T>(),
                                               input->format());
    return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
        AcquireMemoryWithReorder(user_src_md,
                                 this->fwd_pd_->src_desc(),
                                 platform::to_void_cast<T>(input_data));
  }

  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* filter,
      const int& groups) {
    const K* filter_data = filter->data<K>();
    auto weights_tz = GetWeightsTz(filter, groups);
    int g = std::max(groups, 1);

    // Paddle stores conv_transpose weights as (g)IOHW, so describe the user
    // memory accordingly; the reorder below converts to the format the
    // primitive descriptor selected.
    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz,
        platform::MKLDNNGetDataType<K>(),
        (g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);

    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx,
        user_src_md,
        this->fwd_pd_->weights_desc(),
        platform::to_void_cast<K>(filter_data),
        key,
        "@weights_mem_p",
        is_test_);
  }

  template <typename F = T>
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const dnnl::memory::desc& user_md,
      const dnnl::memory::desc& target_md,
      void* ptr,
      const std::string& key,
      const std::string& suffix,
      bool is_persistent = false,
      const std::vector<float>& scale_data = {1.0f},
      int mask = 0) {
    const auto target_key = key + suffix + "_target";
    const auto key_reorder_p = key + suffix + "reorder_p";
    const auto user_key = key + suffix + "_user";

    auto target_memory_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

    if (target_memory_p == nullptr) {
      auto user_memory_p =
          std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
      if (user_md != target_md) {
        target_memory_p =
            std::make_shared<dnnl::memory>(target_md, this->engine_);
        dnnl::reorder::primitive_desc reorder_pdesc;
        if (platform::is_int8<T>()) {
          dnnl::primitive_attr attr;
          attr.set_output_scales(mask, scale_data);
          reorder_pdesc = dnnl::reorder::primitive_desc(
              *user_memory_p, *target_memory_p, attr);
        } else {
          reorder_pdesc =
              dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
        }
        auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
        dev_ctx.SetBlob(key_reorder_p, reorder_p);

        auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      } else {
        target_memory_p = user_memory_p;
      }
      dev_ctx.SetBlob(user_key, user_memory_p);
      dev_ctx.SetBlob(target_key, target_memory_p);
    } else if (!is_persistent) {
      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

      auto user_memory_p =
          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(user_key));
      user_memory_p->set_data_handle(ptr);

      // TODO(jczaja): If a reorder was cached, it is needed; rework this so
      // that the decision does not depend on cache keys.
      auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
          dev_ctx.GetBlob(key_reorder_p));
      if (reorder_p != nullptr) {
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      }
    }
    return target_memory_p;
  }
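  // The blob cache above makes reorders one-time work for persistent data
  // (weights and bias in inference): on a cache hit with is_persistent == true
  // the stored target memory is returned as-is; otherwise the user data handle
  // is refreshed and a previously cached reorder, if any, is re-executed.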

  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* bias) {
    const K* bias_data = bias->data<K>();
    auto user_bias_md =
        platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
                                platform::MKLDNNGetDataType<K>(),
                                MKLDNNMemoryFormat::x);
    return this->AcquireMemoryWithReorder(dev_ctx,
                                          user_bias_md,
                                          this->fwd_pd_->bias_desc(),
                                          platform::to_void_cast<K>(bias_data),
                                          key,
                                          "@bias_mem_p",
                                          is_test_);
  }

 private:
  const bool is_test_;
};

template <typename T, typename K>
class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Operator DNNL ConvTranspose must use CPUPlace"));
    const bool is_bfloat16 =
        ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    if (is_bfloat16) {
      if (force_fp32_output)
        Execute<float>(ctx);
      else
        Execute<platform::bfloat16>(ctx);
    } else {
      Execute<float>(ctx);
    }
  }

  template <typename T_out>
  void Execute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const auto* input = ctx.Input<Tensor>("Input");
    const auto* filter = ctx.Input<Tensor>("Filter");
    const auto* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(
        ctx, mkldnn_engine, input, filter, bias, output);
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
    // A caching key is needed so that the reordered weights can be fetched
    // from the device context on subsequent calls.
    std::string key = platform::CreateKey(dev_ctx,
                                          ctx.InputName("Input"),
                                          ctx.InputName("Filter"),
                                          (bias ? ctx.InputName("Bias") : ""));
    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
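    // Inferred from the helper's name: the key is extended with thread
    // identity when caching is per-thread, so that concurrent executors do
    // not end up sharing cached blobs.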
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
        dev_ctx, key, filter, ctx.Attr<int>("groups"));

    std::shared_ptr<dnnl::memory> dst_memory_p =
        handler.template AcquireDstMemory<T_out>(output);
    auto conv_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

    if (bias) {
      auto bias_memory_p =
          handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
    }
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    conv_p->execute(astream, args);
    astream.wait();
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
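// Two instantiations are registered below: fp32 input and bf16 input; both
// keep the filter type (template parameter K) in fp32.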

REGISTER_OP_KERNEL(
    conv2d_transpose,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    ops::ConvTransposeMKLDNNOpKernel<float, float>,
    ops::ConvTransposeMKLDNNOpKernel<paddle::platform::bfloat16, float>);