/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::DataLayout;

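// Returns the weights dims expected by oneDNN. Paddle stores conv transpose
// filters as IOHW (gIOHW when grouped), while oneDNN deconvolution expects
// OIHW (gOIHW), hence the swap of the two (per-group) leading axes.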
inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
  auto weights_tz = phi::vectorize(filter->dims());
  int g = std::max(groups, 1);
  int g_dim = (g > 1) ? 1 : 0;
  platform::GetGroupConvWeightsTz(weights_tz, g);
  // gIOHW -> gOIHW || IOHW -> OIHW
  std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
  return weights_tz;
}

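// Builds the oneDNN deconvolution (conv transpose) forward primitive
// descriptor from the operator's attributes, validating inputs along the way.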
template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward> {
 public:
  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
                              const dnnl::engine mkldnn_engine,
                              const Tensor* input,
                              const Tensor* filter,
                              const Tensor* bias,
                              Tensor* output)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
            mkldnn_engine, ctx.GetPlace()),
        is_test_(ctx.Attr<bool>("is_test")) {
    PADDLE_ENFORCE_EQ(is_test_,
                      true,
                      platform::errors::InvalidArgument(
                          "ConvTransposeMKLDNN works only for inference. "
                          "The attribute \'is_test\' value should be set to "
                          "True, but got is_test=False."));

    PADDLE_ENFORCE_EQ(
        input->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "Got wrong layout = %d for Input tensor.", input->layout()));

    PADDLE_ENFORCE_EQ(
        filter->layout(),
        DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "The filter tensor's layout should be %d, but got %d.",
            DataLayout::kMKLDNN,
            filter->layout()));

    PADDLE_ENFORCE_EQ(
        input->dims().size(),
        4,
        platform::errors::InvalidArgument("Input must be with 4 dimensions, "
                                          "i.e. NCHW, but got dimension = %d.",
                                          input->dims().size()));
    PADDLE_ENFORCE_EQ(
        filter->dims().size(),
        4,
        platform::errors::InvalidArgument("Filter must be with 4 dimensions, "
                                          "i.e. OIHW, but got dimension = %d.",
                                          filter->dims().size()));

    if (bias) {
      PADDLE_ENFORCE_EQ(
          bias->layout(),
          DataLayout::kMKLDNN,
          platform::errors::InvalidArgument(
              "The bias tensor's layout should be %d, but got %d.",
              DataLayout::kMKLDNN,
              bias->layout()));

      PADDLE_ENFORCE_EQ(
          bias->dims().size(),
          1,
          platform::errors::InvalidArgument("Bias must only have 1 dimension, "
                                            "i.e. X, but got dimension = %d.",
                                            bias->dims().size()));
    }

    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
    dnnl::memory::dims strides(begin(strides_temp), end(strides_temp));

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
    dnnl::memory::dims paddings(begin(paddings_temp), end(paddings_temp));

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
    dnnl::memory::dims dilations(begin(dilations_temp), end(dilations_temp));

    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(
        strides.size(),
        2,
        platform::errors::Unimplemented(
            "Only 2D oneDNN convolution transpose is currently supported."));

    const auto& input_dims = input->dims();
    const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
    const auto& filter_dims = filter->dims();
    const auto filter_data_dims =
        phi::slice_ddim(filter_dims, 2, filter_dims.size());

    const auto ksize = phi::vectorize(filter_data_dims);

    UpdatePaddingAndDilation(
        &paddings, &dilations, padding_algorithm, data_dims, strides, ksize);

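    // oneDNN uses 0-based dilations (0 == dense kernel) while Paddle uses
    // 1-based ones, so shift each dilation by -1.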
    std::transform(
        dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
          return i - 1;
        });

    const auto src_tz = phi::vectorize(input->dims());
    const auto weights_tz = GetWeightsTz(filter, groups);
    const auto dst_tz = phi::vectorize(output->dims());
    const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    /* Create memory descriptors with an unspecified format ('any'), which
     * lets the primitive (deconvolution in this case) choose the memory
     * format preferred for best performance.
     */
    const auto chosen_memory_format = MKLDNNMemoryFormat::any;

    auto data_type = dnnl::memory::data_type::f32;
    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
        std::is_same<T_out, platform::bfloat16>::value)
      data_type = dnnl::memory::data_type::bf16;

    const auto src_md =
        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
    const auto weights_md =
        platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
    const auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

    const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
    auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
                                  : dnnl::prop_kind::forward_training;
    if (bias) {
      std::vector<int64_t> bias_tz = phi::vectorize(bias->dims());
      const auto bias_md =
          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          bias_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    } else {
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr,
          fwd_prop_kind,
          dnnl::algorithm::deconvolution_direct,
          src_md,
          weights_md,
          dst_md,
          strides,
          dilations,
          mkldnn_paddings[0],
          mkldnn_paddings[1]);
    }
  }

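  // Translates the fuse_activation attribute into oneDNN post-ops that are
  // appended to the deconvolution primitive.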
  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
    dnnl::primitive_attr conv_attr;
    dnnl::post_ops post_operations;

    const std::string fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    const float fuse_beta = ctx.Attr<float>("fuse_beta");

    // Fusion with an activation layer is executed through the PostOps
    // feature. Create a PostOps object and configure it to execute the
    // corresponding eltwise operation.
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "swish") {
      constexpr float scale = 1.0f;
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_swish, fuse_alpha, fuse_beta);
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }

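  // Reorders the input from its current memory layout into the layout the
  // primitive descriptor selected for its source, if the two differ.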
  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
      const framework::Tensor* input) {
    const T* input_data = input->data<T>();
    return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
        AcquireMemoryWithReorder(input->mem_desc(),
                                 this->fwd_pd_->src_desc(),
                                 platform::to_void_cast<T>(input_data));
  }

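  // Acquires reordered weights. Unlike the source, weights are constant
  // during inference, so the reordered copy is cached in the device context
  // under the given key and reused across executions.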
  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* filter,
      const int& groups) {
    const K* filter_data = filter->data<K>();
    auto weights_tz = GetWeightsTz(filter, groups);
    int g = std::max(groups, 1);

    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz,
        platform::MKLDNNGetDataType<K>(),
        (g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);

    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx,
        user_src_md,
        this->fwd_pd_->weights_desc(),
        platform::to_void_cast<K>(filter_data),
        key,
        "@weights_mem_p",
        is_test_);
  }

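  // Caching variant of AcquireMemoryWithReorder: looks the target memory up
  // in the device-context blob cache; on a miss it creates user and target
  // memories plus a reorder primitive, on a hit it refreshes the data handle
  // and re-runs the cached reorder unless the blob is persistent.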
  template <typename F = T>
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const dnnl::memory::desc& user_md,
      const dnnl::memory::desc& target_md,
      void* ptr,
      const std::string& key,
      const std::string& suffix,
      bool is_persistent = false,
      const std::vector<float>& scale_data = {1.0f},
      int mask = 0) {
    const auto target_key = key + suffix + "_target";
    const auto key_reorder_p = key + suffix + "reorder_p";
    const auto user_key = key + suffix + "_user";

    auto target_memory_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

    if (target_memory_p == nullptr) {
      auto user_memory_p =
          std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
      if (user_md != target_md) {
        target_memory_p =
            std::make_shared<dnnl::memory>(target_md, this->engine_);
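        // For int8 data the reorder also applies the given output scales
        // (mask + scale_data), as used for quantization.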
        dnnl::reorder::primitive_desc reorder_pdesc;
        if (platform::is_int8<T>()) {
          dnnl::primitive_attr attr;
          attr.set_output_scales(mask, scale_data);
          reorder_pdesc = dnnl::reorder::primitive_desc(
              *user_memory_p, *target_memory_p, attr);
        } else {
          reorder_pdesc =
              dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
        }
        auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
        dev_ctx.SetBlob(key_reorder_p, reorder_p);

        auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      } else {
        target_memory_p = user_memory_p;
      }
      dev_ctx.SetBlob(user_key, user_memory_p);
      dev_ctx.SetBlob(target_key, target_memory_p);
    } else if (!is_persistent) {
      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

      auto user_memory_p =
          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(user_key));
      user_memory_p->set_data_handle(ptr);

      // TODO(jczaja): A cached reorder here means one is actually needed;
      // change this logic to get rid of keys.
      auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
          dev_ctx.GetBlob(key_reorder_p));
      if (reorder_p != nullptr) {
        platform::RecordEvent record_reorder(
            "int_reorder",
            platform::TracerEventType::UserDefined,
            2,
            platform::EventRole::kUniqueOp);
        reorder_p->execute(
            astream,
            {{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
        astream.wait();
      }
    }
    return target_memory_p;
  }

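  // Bias is 1-D ("x" format); reorder it to the layout expected by the
  // primitive descriptor, caching the result like the weights.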
  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
      const platform::MKLDNNDeviceContext& dev_ctx,
      const std::string& key,
      const framework::Tensor* bias) {
    const K* bias_data = bias->data<K>();
    auto user_bias_md =
        platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
                                platform::MKLDNNGetDataType<K>(),
                                MKLDNNMemoryFormat::x);
    return this->AcquireMemoryWithReorder(dev_ctx,
                                          user_bias_md,
                                          this->fwd_pd_->bias_desc(),
                                          platform::to_void_cast<K>(bias_data),
                                          key,
                                          "@bias_mem_p",
                                          is_test_);
  }

 private:
  const bool is_test_;
};

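// Inference kernel: dispatches on mkldnn_data_type and force_fp32_output to
// pick the output data type, then runs the templated Execute.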
template <typename T, typename K>
class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
                      true,
                      platform::errors::PreconditionNotMet(
                          "Operator DNNL ConvTranspose must use CPUPlace"));
    const bool is_bfloat16 =
        ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    if (is_bfloat16) {
      if (force_fp32_output)
        Execute<float>(ctx);
      else
        Execute<platform::bfloat16>(ctx);
    } else {
      Execute<float>(ctx);
    }
  }

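  // Builds the handler, acquires (possibly reordered) src/weights/bias/dst
  // memories, and executes the deconvolution primitive on the oneDNN stream.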
  template <typename T_out>
  void Execute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const auto* input = ctx.Input<Tensor>("Input");
    const auto* filter = ctx.Input<Tensor>("Filter");
    const auto* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(
        ctx, mkldnn_engine, input, filter, bias, output);
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
    // A caching key is needed so reordered weights can be fetched from the
    // device context and reused across executions.
    std::string key = platform::CreateKey(dev_ctx,
                                          ctx.InputName("Input"),
                                          ctx.InputName("Filter"),
                                          (bias ? ctx.InputName("Bias") : ""));
    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
        dev_ctx, key, filter, ctx.Attr<int>("groups"));

    std::shared_ptr<dnnl::memory> dst_memory_p =
        handler.template AcquireDstMemory<T_out>(output);
    auto conv_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};

    if (bias) {
      auto bias_memory_p =
          handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
    }
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    conv_p->execute(astream, args);
    astream.wait();
    output->set_mem_desc(dst_memory_p->get_desc());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(
    conv2d_transpose,
    MKLDNN,
    ::paddle::platform::CPUPlace,
    ops::ConvTransposeMKLDNNOpKernel<float, float>,
    ops::ConvTransposeMKLDNNOpKernel<paddle::platform::bfloat16, float>);