/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {

using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}

template <class Type>
using tf_desc = typename Type::desc;

template <class Type>
using tf_pd = typename Type::primitive_desc;

template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
  auto pd = new tf_pd<Type>(desc, e);
  return std::shared_ptr<tf_pd<Type>>(pd);
}

template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(args...);
  return tf_pd<Type>(desc, e, p);
}

inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these data layouts the channel dimension is either in the 2nd position
  // (nChw) or the last one (nhwC), so for a 2-D tensor the layouts coincide
  // and nothing needs to be done. The same holds for a 1-D tensor, where only
  // one layout is possible.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(dims.begin(), dims.end() - 1,
                std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  switch (from) {
    case framework::DataLayout::kMKLDNN:
      if (to == framework::DataLayout::kNHWC) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
                << print_dims(dims);
      }
      break;
    case framework::DataLayout::kNHWC:
      if (to == framework::DataLayout::kMKLDNN) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
                << print_dims(dims);
      }
      break;
      }
      break;
    default:
      break;
  }
}
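// Example (illustrative): MatchShapeToLayout(&t, framework::DataLayout::kMKLDNN,
// framework::DataLayout::kNHWC) on a hypothetical tensor t with dims
// [2, 3, 4, 5] (NCHW order) resizes it to [2, 4, 5, 3] (NHWC order); the
// kNHWC -> kMKLDNN branch rotates the dims back to [2, 3, 4, 5].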

struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}
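// Example (illustrative): an f32 NCHW descriptor for an 8x3x224x224 tensor:
//   auto md = MKLDNNMemDesc({8, 3, 224, 224}, dnnl::memory::data_type::f32,
//                           MKLDNNMemoryFormat::nchw);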

inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  // Clear the mkl-dnn cache
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->ResetBlobMap(ptr);
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
  // Block the next clearing of the mkl-dnn cache
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->BlockNextCacheClearing();
  }
}

template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}

inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  auto reorder_prim = dnnl::reorder(src, dst);
  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent record_reorder("int_reorder",
                                       platform::EventRole::kUniqueOp);
  reorder_prim.execute(astream, src, dst);
  astream.wait();
}
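// Example (illustrative): given src_mem and dst_mem (hypothetical dnnl::memory
// objects with the same dims but different layouts, created on the same
// engine), Reorder(src_mem, dst_mem, engine) converts the data from the source
// layout into the destination layout.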

inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
                 strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::dcab;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IS IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}
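// Example (illustrative): a 4-D desc created with MKLDNNMemoryFormat::nChw16c
// carries blocking info with inner_nblks == 1, inner_blks[0] == 16 and
// inner_idxs[0] == 1, so GetMKLDNNFormat() maps it back to
// dnnl::memory::format_tag::nChw16c.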

inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  auto mem_desc = memory.get_desc();
  return GetMKLDNNFormat(mem_desc);
}

inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return dnnl::memory::format_tag::abc;
    case 4:
      return dnnl::memory::format_tag::abcd;
    case 5:
      return dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle supports tensors with rank in range <1, 9>, but received "
          "tensor with rank: %d",
          tensor_rank));
  }
}
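// Example (illustrative): GetPlainMKLDNNFormat(4) returns
// dnnl::memory::format_tag::abcd, the plain row-major tag for a 4-D tensor
// (the same layout as nchw).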

inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return MKLDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return MKLDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == MKLDNNMemoryFormat::goihw) {
      return MKLDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == MKLDNNMemoryFormat::goidhw) {
      return MKLDNNMemoryFormat::oidhw;
    }
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncdhw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::abcdef;
    }
  }
  return data_format;
}
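// Example (illustrative): MKLDNNFormatForSize(3, MKLDNNMemoryFormat::nchw)
// returns MKLDNNMemoryFormat::ncw, i.e. the default 4-D layout collapsed to
// its 3-D counterpart; unmatched combinations fall through and return
// data_format unchanged.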

inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  switch (framework::StringToDataLayout(data_format)) {
    case framework::DataLayout::kNHWC:
      return MKLDNNMemoryFormat::nhwc;
    case framework::DataLayout::kNCHW:
      return MKLDNNMemoryFormat::nchw;
    default:
      return MKLDNNMemoryFormat::any;
  }
}

inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (!format->compare("nchw")) {
    return MKLDNNMemoryFormat::nchw;
  } else if (!format->compare("nchw16c")) {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (!format->compare("nhwc")) {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}

inline std::string ThreadIDasStr(void) {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}

template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  key->append(std::to_string(num));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::format_tag& format) {
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::data_type& data_type) {
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
463
                      const dnnl::normalization_flags& flags) {
A
Adam 已提交
464 465 466
  key->append(std::to_string(static_cast<int>(flags)));
}

inline void AppendKey(std::string* key, const std::string& str) {
  key->append(str);
}

inline void AppendKey(std::string* key, const char* str) { key->append(str); }

template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (size_t i = 0; i < dims.size(); i++) {
    AppendKey(key, std::to_string(dims[i]));
  }
}

// If this is an MKLDNN build and a CPU place, register an executor-specific
// key suffix in the DeviceContext.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // Static vars will remember first executor and its thread
    // so both of them need to be processed by the same thread within
    // critical section
    static std::mutex static_vars_barrier;
    static_vars_barrier.lock();
    static auto first_exec = ptr;
    static auto first_thread = ThreadIDasStr();
    static_vars_barrier.unlock();

    if (first_exec != ptr) {
      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
    }
    // Let's register the address of the current executor
    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

    // For first thread
    if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
    }
  }
}

template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
  std::string key;
  key.reserve(64);
  using expand_type = int[];
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
  return key;
}
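// Example (illustrative): the key is the concatenation of all pieces plus the
// thread-local suffix, e.g. (dev_ctx being a hypothetical device context)
//   auto key = CreateKey(dev_ctx, std::vector<int64_t>{8, 3, 224, 224},
//                        "conv2d", MKLDNNGetDataType<float>());
// yields something like "83224224conv2d" followed by the numeric value of the
// f32 data type and the current key suffix.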

inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  return (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
          true)
             ? key + "-t:" + ThreadIDasStr()
             : key;
}

inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    int padding_front = paddings[0];
    int padding_back = paddings[1];
    int padding_top = paddings[2];
    int padding_bottom = paddings[3];
    int padding_left = paddings[4];
    int padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  } else {
    int padding_top = paddings[0];
    int padding_bottom = paddings[1];
    int padding_left = paddings[2];
    int padding_right = paddings[3];

    return {{padding_top, padding_left}, {padding_bottom, padding_right}};
  }
}
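// Example (illustrative): 2-D paddings {top, bottom, left, right} = {1, 2, 3, 4}
// become {{1, 3}, {2, 4}}; the first vector holds the begin ("left") pads and
// the second the end ("right") pads, each in (height, width) order.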

// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}
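// Example (illustrative): with groups == 2, weights dims [o, i, h, w] =
// [32, 16, 3, 3] are rewritten to [g, o/g, i, h, w] = [2, 16, 16, 3, 3].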

inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
          op->GetAttrIfExists<bool>("use_quantizer"));
}

inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}

inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}

enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

template <typename T>
bool constexpr is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

}  // namespace platform
}  // namespace paddle