conv_transpose_mkldnn_op.cc 16.4 KB
Newer Older
J
Jacek Czaja 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

15
#include "boost/optional.hpp"
J
Jacek Czaja 已提交
16 17 18
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
19
#include "paddle/fluid/operators/conv_op.h"
J
Jacek Czaja 已提交
20 21 22 23 24 25 26 27
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using framework::DataLayout;

28
inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
29
  auto weights_tz = phi::vectorize(filter->dims());
30
  int g = std::max(groups, 1);
31
  int g_dim = (g > 1) ? 1 : 0;
32
  platform::GetGroupConvWeightsTz(weights_tz, g);
33 34
  // gIOHW -> gOIHW || IOHW -> OIHW
  std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
35 36 37 38 39
  return weights_tz;
}

template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
40
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward> {
J
Jacek Czaja 已提交
41
 public:
42
  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
43
                              const dnnl::engine mkldnn_engine,
44 45
                              const Tensor* input, const Tensor* filter,
                              const Tensor* bias, Tensor* output)
46
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
            mkldnn_engine, ctx.GetPlace()),
        is_test_(ctx.Attr<bool>("is_test")) {
    PADDLE_ENFORCE_EQ(is_test_, true,
                      platform::errors::InvalidArgument(
                          "ConvTransposeMKLDNN works only for inference. "
                          "The attribute \'is_test\' value should be set to "
                          "True, but got is_test=False."));

    PADDLE_ENFORCE_EQ(
        input->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
            "Got wrong layout = %d for Input tensor.", input->layout()));
    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong format for Input tensor. The input "
                          "format is undefined."));

    PADDLE_ENFORCE_EQ(
        filter->layout(), DataLayout::kMKLDNN,
        platform::errors::InvalidArgument(
67
            "The filter tensor's layout should be %d, but got %d.",
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
            DataLayout::kMKLDNN, filter->layout()));
    PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
                      platform::errors::InvalidArgument(
                          "Got wrong formats for Filter tensor."));

    PADDLE_ENFORCE_EQ(
        input->dims().size(), 4,
        platform::errors::InvalidArgument("Input must be with 4 dimensions, "
                                          "i.e. NCHW. but got dimension =%d",
                                          input->dims().size()));
    PADDLE_ENFORCE_EQ(
        filter->dims().size(), 4,
        platform::errors::InvalidArgument("Filter must be with 4 dimensions, "
                                          "i.e. OIHW, but got dimension =%d",
                                          filter->dims().size()));
F
FDInSky 已提交
83

84
    if (bias) {
F
FDInSky 已提交
85
      PADDLE_ENFORCE_EQ(
86
          bias->layout(), DataLayout::kMKLDNN,
87
          platform::errors::InvalidArgument(
88 89 90
              "The bias tensor's laytout should be %d, but got %d.",
              DataLayout::kMKLDNN, bias->layout()));
      PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
91
                        platform::errors::InvalidArgument(
92
                            "Got wrong format for Bias tensor."));
A
Adam 已提交
93

94
      PADDLE_ENFORCE_EQ(
95 96 97 98 99
          bias->dims().size(), 1,
          platform::errors::InvalidArgument("Bias must only have 1 dimension, "
                                            "i.e. X, but got dimension = %d .",
                                            bias->dims().size()));
    }
100

101
    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
102
    dnnl::memory::dims strides(begin(strides_temp), end(strides_temp));
103 104

    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
105
    dnnl::memory::dims paddings(begin(paddings_temp), end(paddings_temp));
106 107

    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
108
    dnnl::memory::dims dilations(begin(dilations_temp), end(dilations_temp));
109 110 111 112 113 114 115 116 117 118

    int groups = ctx.Attr<int>("groups");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(
        strides.size(), 2,
        platform::errors::Unimplemented(
            "Now we only support 2d oneDNN convolution transpose op"));

    const auto& input_dims = input->dims();
119
    const auto data_dims = phi::slice_ddim(input_dims, 2, input_dims.size());
120 121
    const auto& filter_dims = filter->dims();
    const auto filter_data_dims =
122
        phi::slice_ddim(filter_dims, 2, filter_dims.size());
123

124
    const auto ksize = phi::vectorize(filter_data_dims);
125 126 127 128 129 130 131

    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                             data_dims, strides, ksize);

    std::transform(dilations.begin(), dilations.end(), dilations.begin(),
                   [](int64_t i) { return i - 1; });

132
    const auto src_tz = phi::vectorize(input->dims());
133
    const auto weights_tz = GetWeightsTz(filter, groups);
134
    const auto dst_tz = phi::vectorize(output->dims());
135 136 137 138 139 140 141 142
    const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose
     * the memory format preferred for best performance
     */
    const auto chosen_memory_format = MKLDNNMemoryFormat::any;

143
    auto data_type = dnnl::memory::data_type::f32;
144 145
    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
        std::is_same<T_out, platform::bfloat16>::value)
146
      data_type = dnnl::memory::data_type::bf16;
147 148 149 150 151 152 153 154

    const auto src_md =
        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
    const auto weights_md =
        platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
    const auto dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);

155
    const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
156 157
    auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
                                  : dnnl::prop_kind::forward_training;
158
    if (bias) {
159
      std::vector<int64_t> bias_tz = phi::vectorize(bias->dims());
160 161 162 163 164 165 166 167 168 169 170
      const auto bias_md =
          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr, fwd_prop_kind, dnnl::algorithm::deconvolution_direct,
          src_md, weights_md, bias_md, dst_md, strides, dilations,
          mkldnn_paddings[0], mkldnn_paddings[1]);
    } else {
      this->AcquireForwardPrimitiveDescriptor(
          conv_trans_attr, fwd_prop_kind, dnnl::algorithm::deconvolution_direct,
          src_md, weights_md, dst_md, strides, dilations, mkldnn_paddings[0],
          mkldnn_paddings[1]);
171 172
    }
  }
J
Jacek Czaja 已提交
173

174
  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
175 176
    dnnl::primitive_attr conv_attr;
    dnnl::post_ops post_operations;
177

178 179 180 181 182
    const std::string fuse_activation =
        ctx.Attr<std::string>("fuse_activation");
    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
    const float fuse_beta = ctx.Attr<float>("fuse_beta");

183 184 185 186
    // Fusion with ReLU layer is executed through the PostOps feature. Create a
    // PostOps object and configure it to execute an eltwise relu operation.
    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
      constexpr float scale = 1.0f;
187
      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
188 189 190
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      constexpr float scale = 1.0f;
191 192
      post_operations.append_eltwise(
          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
193 194
    } else if (fuse_activation == "swish") {
      constexpr float scale = 1.0f;
195
      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_swish,
196 197 198 199 200
                                     fuse_alpha, fuse_beta);
    }
    conv_attr.set_post_ops(post_operations);
    return conv_attr;
  }
J
Jacek Czaja 已提交
201

202
  std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
203
      const framework::Tensor* input) {
J
Jacek Czaja 已提交
204
    const T* input_data = input->data<T>();
205
    auto user_src_md = platform::MKLDNNMemDesc(phi::vectorize(input->dims()),
206 207
                                               platform::MKLDNNGetDataType<T>(),
                                               input->format());
208
    return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
209 210
        AcquireMemoryWithReorder(user_src_md, this->fwd_pd_->src_desc(),
                                 platform::to_void_cast<T>(input_data));
211 212
  }

213
  std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
214 215 216 217 218 219 220 221
      const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key,
      const framework::Tensor* filter, const int& groups) {
    const K* filter_data = filter->data<K>();
    auto weights_tz = GetWeightsTz(filter, groups);
    int g = std::max(groups, 1);

    auto user_src_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<K>(),
222
        (g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);
J
Jacek Czaja 已提交
223

224 225
    return this->template AcquireMemoryWithReorder<K>(
        dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
226 227
        platform::to_void_cast<K>(filter_data), key, "@weights_mem_p",
        is_test_);
228
  }
229

230
  template <typename F = T>
231
  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
232
      const platform::MKLDNNDeviceContext& dev_ctx,
233 234 235 236
      const dnnl::memory::desc& user_md, const dnnl::memory::desc& target_md,
      void* ptr, const std::string& key, const std::string& suffix,
      bool is_persistent = false, const std::vector<float>& scale_data = {1.0f},
      int mask = 0) {
237 238 239 240 241 242 243 244 245 246 247 248
    const auto target_key = key + suffix + "_target";
    const auto key_reorder_p = key + suffix + "reorder_p";
    const auto user_key = key + suffix + "_user";

    auto target_memory_p =
        std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));

    if (target_memory_p == nullptr) {
      auto user_memory_p =
          std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
      if (user_md != target_md) {
        target_memory_p =
249
            std::make_shared<dnnl::memory>(target_md, this->engine_);
250 251 252 253 254 255 256 257 258 259 260 261 262 263
        dnnl::reorder::primitive_desc reorder_pdesc;
        if (platform::is_int8<T>()) {
          dnnl::primitive_attr attr;
          attr.set_output_scales(mask, scale_data);
          reorder_pdesc = dnnl::reorder::primitive_desc(*user_memory_p,
                                                        *target_memory_p, attr);
        } else {
          reorder_pdesc =
              dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
        }
        auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
        dev_ctx.SetBlob(key_reorder_p, reorder_p);

        auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
C
chenjian 已提交
264 265 266
        platform::RecordEvent record_reorder(
            "int_reorder", platform::TracerEventType::UserDefined, 2,
            platform::EventRole::kUniqueOp);
267 268
        reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
                                     {DNNL_ARG_TO, *target_memory_p}});
269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
        astream.wait();
      } else {
        target_memory_p = user_memory_p;
      }
      dev_ctx.SetBlob(user_key, user_memory_p);
      dev_ctx.SetBlob(target_key, target_memory_p);
    } else if (!is_persistent) {
      auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

      auto user_memory_p =
          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(user_key));
      user_memory_p->set_data_handle(ptr);

      // TODO(jczaja): Here we detect if reorder is cached it means it is needed
      // need to change this to get rid of keys
284
      auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
285 286
          dev_ctx.GetBlob(key_reorder_p));
      if (reorder_p != nullptr) {
C
chenjian 已提交
287 288 289
        platform::RecordEvent record_reorder(
            "int_reorder", platform::TracerEventType::UserDefined, 2,
            platform::EventRole::kUniqueOp);
290 291
        reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
                                     {DNNL_ARG_TO, *target_memory_p}});
292 293
        astream.wait();
      }
J
Jacek Czaja 已提交
294
    }
295
    return target_memory_p;
296 297
  }

298
  std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
299 300 301 302
      const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key,
      const framework::Tensor* bias) {
    const K* bias_data = bias->data<K>();
    auto user_bias_md = platform::MKLDNNMemDesc(
303
        phi::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
304 305 306 307
        MKLDNNMemoryFormat::x);
    return this->AcquireMemoryWithReorder(
        dev_ctx, user_bias_md, this->fwd_pd_->bias_desc(),
        platform::to_void_cast<K>(bias_data), key, "@bias_mem_p", is_test_);
308
  }
309 310 311

 private:
  const bool is_test_;
312
};
J
Jacek Czaja 已提交
313

314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
template <typename T, typename K>
class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
                      platform::errors::PreconditionNotMet(
                          "Operator DNNL ConvTranspose must use CPUPlace"));
    const bool is_bfloat16 =
        ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
    if (is_bfloat16) {
      if (force_fp32_output)
        Execute<float>(ctx);
      else
        Execute<platform::bfloat16>(ctx);
    } else {
      Execute<float>(ctx);
    }
  }
J
Jacek Czaja 已提交
333

334 335 336 337 338
  template <typename T_out>
  void Execute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();
J
Jacek Czaja 已提交
339

340 341 342 343 344
    const auto* input = ctx.Input<Tensor>("Input");
    const auto* filter = ctx.Input<Tensor>("Filter");
    const auto* bias =
        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
    auto* output = ctx.Output<Tensor>("Output");
345 346
    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(ctx, mkldnn_engine, input,
                                                     filter, bias, output);
347
    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
348 349 350 351 352
    // Caching Key for weights is needed
    std::string key = platform::CreateKey(dev_ctx, ctx.InputName("Input"),
                                          ctx.InputName("Filter"),
                                          (bias ? ctx.InputName("Bias") : ""));
    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
353
    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
354
        dev_ctx, key, filter, ctx.Attr<int>("groups"));
355 356 357 358 359 360

    std::shared_ptr<dnnl::memory> dst_memory_p =
        handler.template AcquireDstMemory<T_out>(output);
    auto conv_p = handler.AcquireForwardPrimitive();

    std::unordered_map<int, dnnl::memory> args = {
361 362 363
        {DNNL_ARG_SRC, *src_memory_p},
        {DNNL_ARG_WEIGHTS, *weights_memory_p},
        {DNNL_ARG_DST, *dst_memory_p}};
A
Adam 已提交
364

J
Jacek Czaja 已提交
365
    if (bias) {
366 367
      auto bias_memory_p =
          handler.AcquireBiasMemoryWithReorder(dev_ctx, key, bias);
368
      args.insert({DNNL_ARG_BIAS, *bias_memory_p});
J
Jacek Czaja 已提交
369
    }
370 371
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    conv_p->execute(astream, args);
A
Adam 已提交
372
    astream.wait();
373 374
    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
J
Jacek Czaja 已提交
375 376 377 378 379 380 381 382
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

383 384 385 386
REGISTER_OP_KERNEL(
    conv2d_transpose, MKLDNN, ::paddle::platform::CPUPlace,
    ops::ConvTransposeMKLDNNOpKernel<float, float>,
    ops::ConvTransposeMKLDNNOpKernel<paddle::platform::bfloat16, float>);