/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

16
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
T
tensor-tang 已提交
28
namespace paddle {
29
#ifdef PADDLE_WITH_MKLDNN
30
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
31
#endif
T
tensor-tang 已提交
32 33
namespace platform {

34 35 36 37 38 39
using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;
T
tensor-tang 已提交
40

41 42 43 44 45
typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
T
tensor-tang 已提交
46

47 48 49 50 51
// Returns `t` as a non-const `void*`, casting away constness first.
// Callers must not write through the result if the pointee is really const.
template <typename Type>
void* to_void_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return static_cast<void*>(mutable_ptr);
}

K
Krzysztof Binias 已提交
52 53 54 55 56
// Same as to_void_cast, but performs the final conversion to `void*` with
// reinterpret_cast instead of static_cast.
template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return reinterpret_cast<void*>(mutable_ptr);
}

57 58 59 60 61 62 63 64 65
// tf_desc<T> and tf_pd<T> name the nested `desc` and `primitive_desc` types of
// a primitive class T.
template <typename Type>
using tf_desc = typename Type::desc;

template <typename Type>
using tf_pd = typename Type::primitive_desc;

// Builds a shared forward primitive descriptor for primitive type `Type` on
// engine `e`. `Type::desc` is constructed from dnnl::prop_kind::forward
// followed by `args`. The args are perfectly forwarded, so rvalue arguments
// are moved instead of copied.
template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward,
                            std::forward<Args>(args)...);
  // make_shared avoids a naked `new` and allocates the object and its control
  // block together.
  return std::make_shared<tf_pd<Type>>(desc, e);
}

// Builds a backward primitive descriptor for primitive type `Type` on engine
// `e`. `p` is passed to the descriptor constructor after the engine,
// presumably as the forward primitive descriptor hint (confirm with callers).
// `args` are perfectly forwarded into `Type::desc`.
template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(std::forward<Args>(args)...);
  return tf_pd<Type>(desc, e, p);
}

78 79 80
// Rewrites the dims of `tensor_in` when converting between kMKLDNN, where the
// channel dim sits at index 1 (e.g. NCHW), and the channel-last layouts
// kNHWC / kNDHWC. Only the tensor's dims are changed (via Resize); no element
// data is touched here. Every other (from, to) pair is a no-op.
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these data layouts, channel dimension is either on 2nd position: nChw or
  // at last nhwC, so for dim==2 these layouts are the same and nothing should
  // be done. Similarly for dim==1 when you have just one possible combination.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  // Renders dims as "[d0,d1,...,dk]" for logging; an empty vector gives "".
  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;
    if (!dims.empty()) {
      oss << "[";
      for (size_t idx = 0; idx + 1 < dims.size(); ++idx) {
        oss << dims[idx] << ",";
      }
      oss << dims.back() << "]";
    }
    return oss.str();
  };

  const bool to_channels_last = (to == framework::DataLayout::kNHWC) ||
                                (to == framework::DataLayout::kNDHWC);
  const bool from_channels_last = (from == framework::DataLayout::kNHWC) ||
                                  (from == framework::DataLayout::kNDHWC);

  if (from == framework::DataLayout::kMKLDNN && to_channels_last) {
    auto dims = phi::vectorize<int>(tensor_in->dims());
    // Move the channel dim from index 1 to the back: [N, C, ...] -> [N, ..., C]
    std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
    tensor_in->Resize(phi::make_ddim(dims));
    VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC/kNDHWC output_shape"
            << print_dims(dims);
  } else if (from_channels_last && to == framework::DataLayout::kMKLDNN) {
    auto dims = phi::vectorize<int>(tensor_in->dims());
    // Move the channel dim from the back to index 1: [N, ..., C] -> [N, C, ...]
    std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
    tensor_in->Resize(phi::make_ddim(dims));
    VLOG(3) << "Rotating Shape from: kNHWC/kNDHWC to: kMKLDNN output_shape"
            << print_dims(dims);
  }
}

130 131 132 133 134
// Empty placeholder type. It has no members other than the empty nested
// `desc` and `primitive_desc` structs. Presumably it stands in where a
// primitive type parameter is required but none applies (confirm with
// callers).
struct mkldnn_dummy_primitive {
  struct desc {};
  struct primitive_desc {};
};

135 136 137 138
// Creates a oneDNN memory descriptor with shape `dims`, element type
// `data_type` and layout tag `format`.
inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  dnnl::memory::desc descriptor({dims}, data_type, format);
  return descriptor;
}

141 142
// Clears the MKL-DNN cache of the device context for `place`: calls
// ResetBlobMap(ptr) on it, then resets the thread-local current Paddle data
// layout to kNCHW. Does nothing for non-CPU places.
inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // static_cast instead of a C-style cast: if the class hierarchy ever stops
  // permitting this downcast, it fails to compile. A C-style cast would
  // silently degrade to a reinterpret_cast.
  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
  dev_ctx->ResetBlobMap(ptr);
  platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
      paddle::framework::DataLayout::kNCHW);
}

154 155 156 157 158 159 160 161 162 163
// Calls BlockNextCacheClearing() on the device context for `place`. Judging
// by the names, this requests that the next cache clearing be skipped; it
// does not clear anything itself (confirm in MKLDNNDeviceContext).
// Does nothing for non-CPU places.
inline void DontClearMKLDNNCache(const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // static_cast instead of a C-style cast so an invalid downcast is a compile
  // error rather than a silent reinterpret_cast.
  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
  dev_ctx->BlockNextCacheClearing();
}

164
template <typename Type>
165 166
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
167 168 169
}

template <>
170 171
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
172 173
}
template <>
174 175
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
176
}
P
Physher 已提交
177
template <>
178 179
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
P
Physher 已提交
180 181
}
template <>
182 183
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
P
Physher 已提交
184 185
}

186
template <>
187 188
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
189 190
}

191 192 193
// Executes a dnnl::reorder primitive from `src` to `dst` on the calling
// thread's MKLDNN stream and waits for it to finish. A RecordEvent named
// "int_reorder" is constructed just before execution; it is presumably a
// scoped profiler event. `engine` is not used.
inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  dnnl::reorder reorder_primitive(src, dst);
  auto& cpu_stream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent reorder_event("int_reorder",
                                      platform::TracerEventType::UserDefined, 2,
                                      platform::EventRole::kUniqueOp);
  reorder_primitive.execute(cpu_stream, src, dst);
  cpu_stream.wait();
}

202
// Best-effort recovery of a format_tag from a memory descriptor. It inspects
// the blocking description (ndims, outer strides, inner blocks) and matches
// it against the handful of layouts listed below.
//  - 1-D descriptors always yield x.
//  - For plain 3-D and 4-D layouts, any stride order not matched explicitly
//    is reported as nwc / nhwc respectively.
//  - Everything else that is not recognized (including ndims > 6) yields
//    format_tag::undef.
inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  // Outer strides: a plain layout matches a tag when its dims, sorted by
  // decreasing stride, appear in the tag's letter order.
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
                 strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::dcab;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde4b;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}

349
// Convenience overload: reports the format_tag of an existing memory object
// by inspecting its descriptor.
inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  return GetMKLDNNFormat(memory.get_desc());
}

354
// Returns the plain format tag for a tensor of rank `tensor_rank`, i.e. the
// tag whose dims appear in their natural order: a, ab, abc, ... abcdefghi.
// Ranks outside [1, 9] raise an Unimplemented error.
inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  using tag = dnnl::memory::format_tag;
  // Indexed by rank - 1.
  static const tag kPlainFormats[] = {
      tag::a,       tag::ab,       tag::abc,
      tag::abcd,    tag::abcde,    tag::abcdef,
      tag::abcdefg, tag::abcdefgh, tag::abcdefghi};

  if (tensor_rank < 1 || tensor_rank > 9) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Paddle support tensors with rank in range <1, 9>, but received "
        "tensor with rank: %d",
        tensor_rank));
  }
  return kPlainFormats[tensor_rank - 1];
}

382 383
// Adapts the layout hint `data_format` to a tensor with `dims_size` dims:
//   1 -> x, 2 -> nc (regardless of the hint);
//   3: nchw -> ncw, nhwc -> nwc;
//   4: goihw -> oihw;
//   5: goidhw -> oidhw, nchw -> ncdhw, nhwc -> ndhwc;
//   6: nchw -> abcdef.
// Any other combination returns `data_format` unchanged.
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  switch (dims_size) {
    case 1:
      return MKLDNNMemoryFormat::x;
    case 2:
      return MKLDNNMemoryFormat::nc;
    case 3:
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::nwc;
      }
      break;
    case 4:
      if (data_format == MKLDNNMemoryFormat::goihw) {
        return MKLDNNMemoryFormat::oihw;
      }
      break;
    case 5:
      if (data_format == MKLDNNMemoryFormat::goidhw) {
        return MKLDNNMemoryFormat::oidhw;
      }
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncdhw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::ndhwc;
      }
      break;
    case 6:
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::abcdef;
      }
      break;
    default:
      break;
  }
  return data_format;
}

415
// Translates a layout string (parsed by framework::StringToDataLayout) into a
// oneDNN format tag: NHWC -> nhwc, NCHW -> nchw, any other layout -> any.
inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  const auto layout = framework::StringToDataLayout(data_format);
  if (layout == framework::DataLayout::kNHWC) {
    return MKLDNNMemoryFormat::nhwc;
  }
  if (layout == framework::DataLayout::kNCHW) {
    return MKLDNNMemoryFormat::nchw;
  }
  return MKLDNNMemoryFormat::any;
}

427
// Lower-cases `*format` in place (the caller's string is modified), then maps
// it to a format tag. Recognized values are "nchw", "nchw16c", "nchw8c" and
// "nhwc"; any other value yields MKLDNNMemoryFormat::any.
inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  // std::tolower has undefined behavior for negative values, which a plain
  // `char` holds for non-ASCII bytes when char is signed. Convert through
  // unsigned char first.
  std::transform(format->begin(), format->end(), format->begin(),
                 [](unsigned char c) {
                   return static_cast<char>(std::tolower(c));
                 });

  if (*format == "nchw") {
    return MKLDNNMemoryFormat::nchw;
  } else if (*format == "nchw16c") {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (*format == "nchw8c") {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (*format == "nhwc") {
    return MKLDNNMemoryFormat::nhwc;
  }
  return MKLDNNMemoryFormat::any;
}

A
Adam 已提交
443 444 445 446 447
// Returns the decimal string of std::hash applied to the calling thread's id.
// The result is stable for a given thread; distinct threads normally, but not
// necessarily, get distinct strings (hash collisions are possible).
inline std::string ThreadIDasStr(void) {
  const std::thread::id this_id = std::this_thread::get_id();
  const std::size_t id_hash = std::hash<std::thread::id>{}(this_id);
  return std::to_string(id_hash);
}

448 449 450
// Appends the std::to_string representation of arithmetic value `num` to
// `*key`.
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  *key += std::to_string(num);
}

A
Adam 已提交
453 454
template <>
inline void AppendKey(std::string* key,
455
                      const dnnl::memory::format_tag& format) {
A
Adam 已提交
456 457 458 459 460
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
461
                      const dnnl::memory::data_type& data_type) {
A
Adam 已提交
462 463 464 465
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
466
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
A
Adam 已提交
467 468 469 470 471
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
472
                      const dnnl::normalization_flags& flags) {
A
Adam 已提交
473 474 475
  key->append(std::to_string(static_cast<int>(flags)));
}

476 477
// Textual key components are appended verbatim, with no separator.
inline void AppendKey(std::string* key, const std::string& str) {
  *key += str;
}

inline void AppendKey(std::string* key, const char* str) {
  *key += str;
}
A
Adam 已提交
481

A
Adam 已提交
482 483
// Appends each element of `dims` (converted with std::to_string) in order.
// NOTE: elements are concatenated without a separator, so e.g. {1, 23} and
// {12, 3} produce the same text.
template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (const auto& dim : dims) {
    AppendKey(key, std::to_string(dim));
  }
}

489 490 491 492
// For CPU places, registers executor pointer `ptr` with the thread-local
// MKLDNN device context so that cache keys can be told apart:
//  - for any `ptr` other than the first one ever registered, the thread-local
//    key suffix is set to "E<address of ptr>" (CreateKey appends this suffix);
//  - `ptr` is recorded as the current executor;
//  - on the thread that made the first registration, disable_tid_in_key() is
//    called (presumably so its keys omit the thread id).
// Does nothing for non-CPU places.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }

  // Static vars will remember first executor and its thread, so both of them
  // need to be initialized by the same thread within a critical section.
  // unique_lock releases the mutex even if an initializer throws (e.g.
  // std::bad_alloc inside ThreadIDasStr). A manual lock()/unlock() pair would
  // leave it locked forever and deadlock every later caller.
  static std::mutex static_vars_barrier;
  std::unique_lock<std::mutex> static_vars_lock(static_vars_barrier);
  static auto first_exec = ptr;
  static auto first_thread = ThreadIDasStr();
  static_vars_lock.unlock();

  if (first_exec != ptr) {
    paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
        "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
  }
  // Let's register address of current executor
  paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

  // For first thread
  if (first_thread == ThreadIDasStr()) {
    paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
  }
}

516
template <typename... ArgTypes>
517 518
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
519
  std::string key;
520
  key.reserve(64);
521
  using expand_type = int[];
522
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
J
Jacek Czaja 已提交
523
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
524 525 526
  return key;
}

527 528
// Returns `key` + "-t:" + ThreadIDasStr() when the thread-local context
// reports is_tid_used_in_key(); otherwise returns `key` unchanged.
// `dev_ctx` is not used.
inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  if (!paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key()) {
    return key;
  }
  return key + "-t:" + ThreadIDasStr();
}

A
Adam 已提交
535 536
// Splits Paddle's interleaved padding list into a {leading, trailing} pair of
// per-spatial-dim paddings:
//   6 values [front, back, top, bottom, left, right] ->
//       {{front, top, left}, {back, bottom, right}}
//   otherwise the first 4 values [top, bottom, left, right] ->
//       {{top, left}, {bottom, right}}
// Precondition: paddings.size() == 6 or paddings.size() >= 4.
// Values stay int64_t throughout. Routing them through `int` locals would
// silently truncate paddings outside the int range.
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    const int64_t padding_front = paddings[0];
    const int64_t padding_back = paddings[1];
    const int64_t padding_top = paddings[2];
    const int64_t padding_bottom = paddings[3];
    const int64_t padding_left = paddings[4];
    const int64_t padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  }

  const int64_t padding_top = paddings[0];
  const int64_t padding_bottom = paddings[1];
  const int64_t padding_left = paddings[2];
  const int64_t padding_right = paddings[3];

  return {{padding_top, padding_left}, {padding_bottom, padding_right}};
}

557 558 559 560 561 562 563 564 565 566 567 568 569
// Adjusts the weight dims for grouped convolutions, in place. When
// groups > 1, it prepends the group count and divides the old first dim by it:
//   [o, i, h, w]    -> [g, o/g, i, h, w]
//   [o, i, d, h, w] -> [g, o/g, i, d, h, w]
// With groups <= 1 the vector is left untouched.
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups <= 1) {
    return;
  }
  weights_tz.insert(weights_tz.begin(), static_cast<int64_t>(groups));
  weights_tz[1] /= groups;
}

J
Jacek Czaja 已提交
570 571 572 573
// For CPU places, walks `ops` in order and stops at the first op that has a
// "data_format" or "data_layout" attribute (checked in that order). It records
// kNHWC as the thread-local current Paddle data layout if that attribute
// equals "NHWC", and kNCHW otherwise. Nothing is recorded if no op has either
// attribute or if `place` is not a CPU place.
inline void RegisterModelLayout(
    std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
    const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  VLOG(4) << "RegisterModelLayout for mkldnn";

  // Records the layout and returns true if `op` has `attrib_name`.
  auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                         const std::string& attrib_name) -> bool {
    if (!op->HasAttr(attrib_name)) {
      return false;
    }
    const auto data_format = op->Attr<std::string>(attrib_name);
    const auto layout = (data_format == "NHWC") ? framework::DataLayout::kNHWC
                                                : framework::DataLayout::kNCHW;
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(layout);
    return true;
  };

  for (auto& op : ops) {
    if (check_attrib(op, std::string("data_format")) ||
        check_attrib(op, std::string("data_layout"))) {
      return;
    }
  }
}

599 600 601 602 603
// Predicates over an op's "mkldnn_data_type" attribute, read with
// GetAttrIfExists. That presumably yields a default value when the attribute
// is absent (confirm in OpDesc).

// True if "mkldnn_data_type" is "int8" or the "use_quantizer" flag is set.
inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  const bool int8_requested =
      op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8";
  return int8_requested || op->GetAttrIfExists<bool>("use_quantizer");
}

// True if "mkldnn_data_type" is "bfloat16".
inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  const std::string requested =
      op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return requested == "bfloat16";
}

// True if "mkldnn_data_type" is "float32".
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  const std::string requested =
      op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return requested == "float32";
}
A
Adam Osewski 已提交
611

A
Adam 已提交
612 613
// Direction of an RNN data reorder. Enumerator names read as
// <source>_<destination>. "PP" presumably denotes Paddle's own layout and
// NTC / TNC the oneDNN sequence layouts (confirm with the RNN kernels).
enum class RNNReorderType {
  PP_NTC,
  PP_TNC,
  NTC_PP,
  TNC_PP
};

A
Adam Osewski 已提交
614 615 616 617 618
// Compile-time check: true iff T is exactly int8_t or uint8_t (cv-qualified
// variants do not match).
template <typename T>
constexpr bool is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

T
tensor-tang 已提交
619 620
}  // namespace platform
}  // namespace paddle