/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {

using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}

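// Shorthand aliases for a primitive's descriptor types, plus small factories
// that build forward/backward primitive descriptors on a given engine.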
template <class Type>
using tf_desc = typename Type::desc;

template <class Type>
using tf_pd = typename Type::primitive_desc;

template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
  auto pd = new tf_pd<Type>(desc, e);
  return std::shared_ptr<tf_pd<Type>>(pd);
}

template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(args...);
  return tf_pd<Type>(desc, e, p);
}

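// Permute the tensor's shape so that it matches the target layout when
// converting between kMKLDNN and kNHWC; a no-op for tensors with fewer than
// three dimensions.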
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these data layouts the channel dimension is either in the second
  // position (nChw) or in the last position (nhwC), so for dims == 2 the
  // layouts coincide and nothing needs to be done. The same holds for
  // dims == 1, where only one layout is possible.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(dims.begin(), dims.end() - 1,
                std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  switch (from) {
    case framework::DataLayout::kMKLDNN:
      if (to == framework::DataLayout::kNHWC) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
                << print_dims(dims);
      }
      break;
    case framework::DataLayout::kNHWC:
      if (to == framework::DataLayout::kMKLDNN) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
                << print_dims(dims);
      }
      break;
    default:
      break;
  }
}

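// Dummy primitive type exposing empty desc and primitive_desc members, used
// as a placeholder where a template expects a primitive type but no real
// oneDNN primitive applies.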
struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

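// Build an oneDNN memory descriptor from dimensions, data type and format tag.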
inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}

inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  // Clear the MKL-DNN cache held by the CPU device context.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->ResetBlobMap(ptr);
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
  // Block the next clearing of the MKL-DNN cache.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->BlockNextCacheClearing();
  }
}

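// Map a C++ element type to the corresponding oneDNN data type; unsupported
// types resolve to data_type::undef.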
template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}

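// Run an immediate reorder from src to dst memory on the thread-local oneDNN
// stream (recorded as an "int_reorder" profiler event).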
inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  auto reorder_prim = dnnl::reorder(src, dst);
  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent record_reorder("int_reorder",
                                       platform::EventRole::kUniqueOp);
  reorder_prim.execute(astream, src, dst);
  astream.wait();
}

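// Deduce a named oneDNN format_tag from a memory descriptor by inspecting its
// rank, strides and inner blocking; returns format_tag::undef if the layout
// is not recognized.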
inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
                 strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::dcab;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}

inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  auto mem_desc = memory.get_desc();
  return GetMKLDNNFormat(mem_desc);
}

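// Return the plain (non-blocked, row-major) format tag for the given rank.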
inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
      break;
    case 2:
      return dnnl::memory::format_tag::ab;
      break;
    case 3:
      return dnnl::memory::format_tag::abc;
      break;
    case 4:
      return dnnl::memory::format_tag::abcd;
      break;
    case 5:
      return dnnl::memory::format_tag::abcde;
      break;
    case 6:
      return dnnl::memory::format_tag::abcdef;
      break;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
      break;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
      break;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle supports tensors with rank in range <1, 9>, but received "
          "tensor with rank: %d",
          tensor_rank));
  }
}

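// Adapt a (typically 4-D) format hint to a tag that matches the actual number
// of dimensions; unrecognized combinations are returned unchanged.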
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return MKLDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return MKLDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == MKLDNNMemoryFormat::goihw) {
      return MKLDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == MKLDNNMemoryFormat::goidhw) {
      return MKLDNNMemoryFormat::oidhw;
    }
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncdhw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::abcdef;
    }
  }
  return data_format;
}

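// Translate a Paddle data_format attribute string (e.g. "NHWC", "NCHW") into
// the corresponding oneDNN format tag, defaulting to format_tag::any.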
inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  switch (framework::StringToDataLayout(data_format)) {
    case framework::DataLayout::kNHWC:
      return MKLDNNMemoryFormat::nhwc;
    case framework::DataLayout::kNCHW:
      return MKLDNNMemoryFormat::nchw;
    default:
      return MKLDNNMemoryFormat::any;
  }
}

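// Map a format name (lower-cased in place) to an oneDNN format tag; unknown
// names fall back to format_tag::any.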
inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (!format->compare("nchw")) {
    return MKLDNNMemoryFormat::nchw;
  } else if (!format->compare("nchw16c")) {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (!format->compare("nhwc")) {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}

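// Helpers for building cache keys: a hashed thread id plus AppendKey
// overloads that serialize the supported argument types into a string.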
inline std::string ThreadIDasStr(void) {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}

template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  key->append(std::to_string(num));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::format_tag& format) {
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::data_type& data_type) {
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::normalization_flags& flags) {
  key->append(std::to_string(static_cast<int>(flags)));
}

inline void AppendKey(std::string* key, const std::string& str) {
  key->append(str);
}

inline void AppendKey(std::string* key, const char* str) { key->append(str); }

template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (size_t i = 0; i < dims.size(); i++) {
    AppendKey(key, std::to_string(dims[i]));
  }
}

// In an MKL-DNN build with a CPU place, register an executor-specific suffix
// in the DeviceContext.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // Static vars will remember the first executor and its thread, so both
    // of them need to be processed by the same thread within a critical
    // section.
    static std::mutex static_vars_barrier;
    static_vars_barrier.lock();
    static auto first_exec = ptr;
    static auto first_thread = ThreadIDasStr();
    static_vars_barrier.unlock();

    if (first_exec != ptr) {
      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
    }
    // Let's register the address of the current executor
    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

    // For first thread
    if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
    }
  }
}

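// Concatenate the textual representation of all arguments into a cache key
// and append the per-executor suffix kept in the device context's TLS.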
template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
  std::string key;
  key.reserve(64);
  using expand_type = int[];
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
  return key;
}

inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  return (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
          true)
             ? key + "-t:" + ThreadIDasStr()
             : key;
}

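// Split a flat paddings vector ({top, bottom, left, right}, or with
// {front, back} prepended for 3-D) into the {padding_l, padding_r} pair
// expected by oneDNN primitives.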
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    int padding_front = paddings[0];
    int padding_back = paddings[1];
    int padding_top = paddings[2];
    int padding_bottom = paddings[3];
    int padding_left = paddings[4];
    int padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  } else {
    int padding_top = paddings[0];
    int padding_bottom = paddings[1];
    int padding_left = paddings[2];
    int padding_right = paddings[3];

    return {{padding_top, padding_left}, {padding_bottom, padding_right}};
  }
}

// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}

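// Helpers that check which MKL-DNN data type an operator requests via its
// attributes.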
inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
          op->GetAttrIfExists<bool>("use_quantizer"));
}

inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}

inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}

enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

template <typename T>
bool constexpr is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

}  // namespace platform
}  // namespace paddle