conv_transpose_mkldnn_op.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::DataLayout;

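// Computes the oneDNN weights dims ("tz") for a (possibly grouped) conv
// transpose filter. Paddle keeps transposed-conv filters as IOHW (gIOHW when
// grouped), while oneDNN expects OIHW (gOIHW), hence the axis swap below.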
inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
  auto weights_tz = phi::vectorize(filter->dims());
  int g = std::max(groups, 1);
  int g_dim = (g > 1) ? 1 : 0;
  platform::GetGroupConvWeightsTz(weights_tz, g);
  // gIOHW -> gOIHW || IOHW -> OIHW
  std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
  return weights_tz;
}

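// Handler that builds the oneDNN deconvolution_forward primitive backing
// conv2d_transpose. T is the input data type, K the filter data type, and
// T_out the output data type (fp32, or bf16 when bfloat16 execution is used).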
template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward> {
 public:
  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
                              const dnnl::engine mkldnn_engine,
                              const Tensor* input,
                              const Tensor* filter,
                              const Tensor* bias,
                              Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
            mkldnn_engine, ctx.GetPlace()),
        is_test_(ctx.Attr<bool>("is_test")) {
    PADDLE_ENFORCE_EQ(is_test_,
                      true,
                      platform::errors::InvalidArgument(
                          "ConvTransposeMKLDNN works only for inference. "
                          "The attribute \'is_test\' value should be set to "
                          "True, but got is_test=False."));

    PADDLE_ENFORCE_EQ(
        input->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "Got wrong layout = %d for Input tensor.", input->layout()));
    PADDLE_ENFORCE_NE(input->format(),
                      MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Input tensor. The input "
                          "format is undefined."));

    PADDLE_ENFORCE_EQ(
        filter->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The filter tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN,
            filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(),
                      MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Filter tensor."));

    PADDLE_ENFORCE_EQ(
        input->dims().size(),
        4,
        platform::errors::InvalidArgument("Input must be with 4 dimensions, "
                                          "i.e. NCHW, but got dimension = %d",
                                          input->dims().size()));
    PADDLE_ENFORCE_EQ(
        filter->dims().size(),
        4,
        platform::errors::InvalidArgument("Filter must be with 4 dimensions, "
                                          "i.e. OIHW, but got dimension = %d",
                                          filter->dims().size()));

    if (bias) {
      PADDLE_ENFORCE_EQ(
          bias->layout(),
          DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The bias tensor's layout should be %d, but got %d.",
              DataLayout::kMKLDNN,
              bias->layout()));
      PADDLE_ENFORCE_NE(bias->format(),
                        MKLDNNMemoryFormat::undef,
                        platform::errors::InvalidArgument(
                            "Got wrong format for Bias tensor."));

      PADDLE_ENFORCE_EQ(
          bias->dims().size(),
          1,
          platform::errors::InvalidArgument("Bias must only have 1 dimension, "
                                            "i.e. X, but got dimension = %d.",
                                            bias->dims().size()));
    }

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    dnnl::memory::dims strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    dnnl::memory::dims paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    dnnl::memory::dims dilations(begin(dilations_temp), end(dilations_temp));

    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(
        strides.size(),
        2,
        platform::errors::Unimplemented(
            "Only 2D oneDNN convolution transpose is supported for now"));

    const auto& input_dims = input->dims();
    const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
    const auto& filter_dims = filter->dims();
    const auto filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());

    const auto ksize = phi::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(
        &paddings, &dilations, padding_algorithm, data_dims, strides, ksize);

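    // oneDNN uses a 0-based dilation convention (0 means no dilation), so
    // shift Paddle's 1-based dilations down by one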
    std::transform(
        dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
          return i - 1;
        });

    const auto src_tz = phi::vectorize(input->dims());
    const auto weights_tz = GetWeightsTz(filter, groups);
    const auto dst_tz = phi::vectorize(output->dims());
    const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    const auto chosen_memory_format = MKLDNNMemoryFormat::any;

    auto data_type = dnnl::memory::data_type::f32;
    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
        std::is_same<T_out, platform::bfloat16>::value)
      data_type = dnnl::memory::data_type::bf16;

    const auto src_md =
        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
    const auto weights_md =
        platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
    const auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

    const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
    auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
                                  : dnnl::prop_kind::forward_training;
    if (bias) {
      std::vector<int64_t> bias_tz = phi::vectorize(bias->dims());
      const auto bias_md =
          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          bias_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    } else {
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    }
  }

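  // Builds the primitive attributes carrying the fused activation
  // (relu/leaky_relu, relu6, swish) as an eltwise post-op.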
  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
    dnnl::primitive_attr conv_attr;
    dnnl::post_ops post_operations;

    const std::string fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    const float fuse_beta = ctx.Attr<float>("fuse_beta");

    // Activation fusion is executed through the PostOps feature: create a
    // PostOps object and configure it to append the matching eltwise op.
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "swish") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_swish, fuse_alpha, fuse_beta);
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }

  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
      const framework::Tensor* input) {
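    // Reorders the user input into the layout the primitive selected for its
    // source, if the two memory descriptors differ.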
    const T* input_data = input->data<T>();
    auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
                                               platform::MKLDNNGetDataType<T>(),
                                               input->format());
    return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
        AcquireMemoryWithReorder(user_src_md,
                                 this->fwd_pd_->src_desc(),
                                 platform::to_void_cast<T>(input_data));
  }

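  // Weights are persistent during inference, so the reordered weights memory
  // is cached in the device context under `key` and reused on later runs.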
  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* filter,
      const int& groups) {
    const K* filter_data = filter->data<K>();
    auto weights_tz = GetWeightsTz(filter, groups);
    int g = std::max(groups, 1);

    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz,
        platform::MKLDNNGetDataType<K>(),
        (g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);

    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx,
        user_src_md,
        this->fwd_pd_->weights_desc(),
        platform::to_void_cast<K>(filter_data),
        key,
        "@weights_mem_p",
        is_test_);
  }
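
  // Caching variant of AcquireMemoryWithReorder: the user memory, the target
  // memory, and the reorder primitive are stored as blobs in the device
  // context under `key + suffix`, so persistent data is reordered only once.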
  template <typename F = T>
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const dnnl::memory::desc& user_md,
      const dnnl::memory::desc& target_md,
      void* ptr,
      const std::string& key,
      const std::string& suffix,
      bool is_persistent = false,
      const std::vector<float>& scale_data = {1.0f},
      int mask = 0) {
    const auto target_key = key + suffix + "_target";
    const auto key_reorder_p = key + suffix + "reorder_p";
    const auto user_key = key + suffix + "_user";

    auto target_memory_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

    if (target_memory_p == nullptr) {
      auto user_memory_p =
          std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
      if (user_md != target_md) {
        target_memory_p =
            std::make_shared<dnnl::memory>(target_md, this->engine_);
        dnnl::reorder::primitive_desc reorder_pdesc;
        if (platform::is_int8<T>()) {
          dnnl::primitive_attr attr;
          attr.set_output_scales(mask, scale_data);
          reorder_pdesc = dnnl::reorder::primitive_desc(
              *user_memory_p, *target_memory_p, attr);
        } else {
          reorder_pdesc =
              dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
        }
        auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
        dev_ctx.SetBlob(key_reorder_p, reorder_p);

        auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      } else {
        target_memory_p = user_memory_p;
      }
      dev_ctx.SetBlob(user_key, user_memory_p);
      dev_ctx.SetBlob(target_key, target_memory_p);
    } else if (!is_persistent) {
      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

      auto user_memory_p =
          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(user_key));
      user_memory_p->set_data_handle(ptr);

      // TODO(jczaja): If a reorder is cached here, it is needed; rework this
      // detection so it does not rely on cache keys
      auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
          dev_ctx.GetBlob(key_reorder_p));
      if (reorder_p != nullptr) {
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      }
    }
    return target_memory_p;
  }

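  // Bias is reordered and cached like the weights; oneDNN expects bias as a
  // plain 1-D tensor (format x).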
  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* bias) {
    const K* bias_data = bias->data<K>();
    auto user_bias_md =
        platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
                                platform::MKLDNNGetDataType<K>(),
                                MKLDNNMemoryFormat::x);
    return this->AcquireMemoryWithReorder(dev_ctx,
                                          user_bias_md,
                                          this->fwd_pd_->bias_desc(),
                                          platform::to_void_cast<K>(bias_data),
                                          key,
                                          "@bias_mem_p",
                                          is_test_);
  }

 private:
  const bool is_test_;
};

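// Dispatches on the requested output type: bf16 inputs produce bf16 output
// unless force_fp32_output is set; fp32 inputs always run in fp32.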
template <typename T, typename K>
class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Operator DNNL ConvTranspose must use CPUPlace"));
    const bool is_bfloat16 =
        ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    if (is_bfloat16) {
      if (force_fp32_output)
        Execute<float>(ctx);
      else
        Execute<platform::bfloat16>(ctx);
    } else {
      Execute<float>(ctx);
    }
  }

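  // Builds the handler, acquires (and reorders, if needed) the src/weights/
  // bias memories, then executes the deconvolution primitive on the stream.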
  template <typename T_out>
  void Execute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const auto* input = ctx.Input<Tensor>("Input");
    const auto* filter = ctx.Input<Tensor>("Filter");
    const auto* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(
        ctx, mkldnn_engine, input, filter, bias, output);
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
    // A caching key is needed so the reordered weights memory can be reused
    std::string key = platform::CreateKey(dev_ctx,
                                          ctx.InputName("Input"),
                                          ctx.InputName("Filter"),
                                          (bias ? ctx.InputName("Bias") : ""));
    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
        dev_ctx, key, filter, ctx.Attr<int>("groups"));

    std::shared_ptr<dnnl::memory> dst_memory_p =
        handler.template AcquireDstMemory<T_out>(output);
    auto conv_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

    if (bias) {
      auto bias_memory_p =
          handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
    }
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    conv_p->execute(astream, args);
    astream.wait();
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

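// Both fp32 and bf16 kernels are registered; in either case the filter
// type parameter (K) stays fp32.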
REGISTER_OP_KERNEL(
    conv2d_transpose,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    ops::ConvTransposeMKLDNNOpKernel<float, float>,
    ops::ConvTransposeMKLDNNOpKernel<paddle::platform::bfloat16, float>);