/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {

using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}

template <class Type>
using tf_desc = typename Type::desc;

template <class Type>
using tf_pd = typename Type::primitive_desc;

template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
  auto pd = new tf_pd<Type>(desc, e);
  return std::shared_ptr<tf_pd<Type>>(pd);
}

template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(args...);
  return tf_pd<Type>(desc, e, p);
}
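
// Usage sketch (illustrative, not part of the original interface):
// MKLDNNFwdPrimitiveDesc prepends dnnl::prop_kind::forward to the arguments of
// Type::desc and wraps the resulting primitive_desc in a shared_ptr. Assuming
// a primitive such as dnnl::softmax_forward, whose desc takes
// (prop_kind, memory::desc, axis), a forward primitive_desc could be built as:
//   auto pd = MKLDNNFwdPrimitiveDesc<dnnl::softmax_forward>(engine, src_md, 1);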

inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these data layouts the channel dimension is either in the 2nd position
  // (nChw) or in the last position (nhwC), so for dim==2 both layouts coincide
  // and nothing needs to be done. The same holds for dim==1, where only one
  // layout is possible.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(dims.begin(), dims.end() - 1,
                std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  switch (from) {
    case framework::DataLayout::kMKLDNN:
      if (to == framework::DataLayout::kNHWC) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
                << print_dims(dims);
      }
      break;
    case framework::DataLayout::kNHWC:
      if (to == framework::DataLayout::kMKLDNN) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
                << print_dims(dims);
      }
      break;
    default:
      break;
  }
}
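
// Worked example (illustrative): a 4-D tensor with dims [2, 3, 4, 5] stored as
// kMKLDNN (NCHW order) converted to kNHWC gets its channel dim rotated to the
// back, [2, 3, 4, 5] -> [2, 4, 5, 3]; the kNHWC -> kMKLDNN branch performs the
// inverse rotation, moving the last dim back to position 1.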

struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}
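
// Usage sketch (illustrative): describing a plain FP32 NCHW tensor of shape
// [8, 3, 224, 224]:
//   auto md = MKLDNNMemDesc({8, 3, 224, 224}, dnnl::memory::data_type::f32,
//                           MKLDNNMemoryFormat::nchw);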

inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  // Clear the mkl-dnn cache.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->ResetBlobMap(ptr);
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
  // Block the next clearing of the mkl-dnn cache.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->BlockNextCacheClearing();
  }
}

template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}
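
// Usage sketch (illustrative): MKLDNNGetDataType maps a C++ element type to
// the corresponding oneDNN data type at compile time, e.g. when building a
// memory descriptor:
//   auto md = MKLDNNMemDesc(dims, MKLDNNGetDataType<float>(),  // -> f32
//                           MKLDNNMemoryFormat::nchw);
// Types without a specialization fall back to dnnl::memory::data_type::undef.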

inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  auto reorder_prim = dnnl::reorder(src, dst);
  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent record_reorder("int_reorder",
                                       platform::TracerEventType::UserDefined,
                                       2, platform::EventRole::kUniqueOp);
  reorder_prim.execute(astream, src, dst);
  astream.wait();
}

inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
                 strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::dcab;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}
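
// Worked example (illustrative): GetMKLDNNFormat recovers the format tag from
// the blocking structure of a descriptor. For a dense 4-D tensor of dims
// [n, c, h, w] with no inner blocks, NCHW data has strides {c*h*w, h*w, w, 1}
// (non-increasing), which matches the first 4-D branch and yields nchw, while
// NHWC data has strides {h*w*c, 1, w*c, c}, which falls through to the final
// else branch and yields nhwc.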

inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  auto mem_desc = memory.get_desc();
  return GetMKLDNNFormat(mem_desc);
}

inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return dnnl::memory::format_tag::abc;
    case 4:
      return dnnl::memory::format_tag::abcd;
    case 5:
      return dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle supports tensors with rank in range <1, 9>, but received "
          "tensor with rank: %d",
          tensor_rank));
  }
}

inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return MKLDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return MKLDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == MKLDNNMemoryFormat::goihw) {
      return MKLDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == MKLDNNMemoryFormat::goidhw) {
      return MKLDNNMemoryFormat::oidhw;
    }
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncdhw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::abcdef;
    }
  }
  return data_format;
}
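
// Illustrative examples: MKLDNNFormatForSize remaps a canonical format hint to
// the tensor's actual rank, e.g.
//   MKLDNNFormatForSize(3, MKLDNNMemoryFormat::nchw)  -> ncw
//   MKLDNNFormatForSize(4, MKLDNNMemoryFormat::goihw) -> oihw
//   MKLDNNFormatForSize(5, MKLDNNMemoryFormat::nhwc)  -> ndhwc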

inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  switch (framework::StringToDataLayout(data_format)) {
    case framework::DataLayout::kNHWC:
      return MKLDNNMemoryFormat::nhwc;
    case framework::DataLayout::kNCHW:
      return MKLDNNMemoryFormat::nchw;
    default:
      return MKLDNNMemoryFormat::any;
  }
}

inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (!format->compare("nchw")) {
    return MKLDNNMemoryFormat::nchw;
  } else if (!format->compare("nchw16c")) {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (!format->compare("nhwc")) {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}

inline std::string ThreadIDasStr(void) {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}

template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  key->append(std::to_string(num));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::format_tag& format) {
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::data_type& data_type) {
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::normalization_flags& flags) {
  key->append(std::to_string(static_cast<int>(flags)));
}

inline void AppendKey(std::string* key, const std::string& str) {
  key->append(str);
}

inline void AppendKey(std::string* key, const char* str) { key->append(str); }

template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (size_t i = 0; i < dims.size(); i++) {
    AppendKey(key, std::to_string(dims[i]));
  }
}

// If this is an MKLDNN build and a CPU place, register a key suffix in the
// DeviceContext
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // Static vars will remember the first executor and its thread, so both of
    // them need to be initialized by the same thread within a critical section
    static std::mutex static_vars_barrier;
    static_vars_barrier.lock();
    static auto first_exec = ptr;
    static auto first_thread = ThreadIDasStr();
    static_vars_barrier.unlock();

    if (first_exec != ptr) {
      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
    }
    // Let's register the address of the current executor
    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

    // For the first thread
    if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
    }
  }
}

template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
  std::string key;
  key.reserve(64);
  using expand_type = int[];
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
  return key;
}
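
// Usage sketch (illustrative, names are hypothetical): a cache key is built by
// concatenating the string form of every argument plus the per-executor suffix
// kept in TLS, e.g.
//   auto key = CreateKey(dev_ctx, "conv2d", src_tz, MKLDNNGetDataType<float>());
// where src_tz is a std::vector<int64_t> of input dimensions.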

inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  return (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
          true)
             ? key + "-t:" + ThreadIDasStr()
             : key;
}

inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    int padding_front = paddings[0];
    int padding_back = paddings[1];
    int padding_top = paddings[2];
    int padding_bottom = paddings[3];
    int padding_left = paddings[4];
    int padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  } else {
    int padding_top = paddings[0];
    int padding_bottom = paddings[1];
    int padding_left = paddings[2];
    int padding_right = paddings[3];

    return {{padding_top, padding_left}, {padding_bottom, padding_right}};
  }
}
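
// Worked example (illustrative): 2-D paddings {top, bottom, left, right} such
// as {1, 2, 3, 4} become {{1, 3}, {2, 4}}, while 3-D paddings
// {front, back, top, bottom, left, right} such as {0, 1, 2, 3, 4, 5} become
// {{0, 2, 4}, {1, 3, 5}}.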

// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}
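
// Worked example (illustrative): for a 2-D convolution with groups = 2 and
// weights_tz = [32, 16, 3, 3] ([o, i, h, w]), the vector is rewritten in place
// to [2, 16, 16, 3, 3] ([g, o/g, i, h, w]); with groups == 1 it is left
// untouched.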

inline void RegisterModelLayout(
    std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
    const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    VLOG(4) << "RegisterModelLayout for mkldnn";
    auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                           const std::string& attrib_name) -> bool {
      if (op->HasAttr(attrib_name)) {
        auto data_format = op->Attr<std::string>(attrib_name);
        platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
            data_format.compare("NHWC") == 0 ? framework::DataLayout::kNHWC
                                             : framework::DataLayout::kNCHW);
        return true;
      } else {
        return false;
      }
    };

    for (auto& op : ops) {
      if (check_attrib(op, std::string("data_format"))) {
        return;
      }
      if (check_attrib(op, std::string("data_layout"))) {
        return;
      }
    }
  }
}

inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
          op->GetAttrIfExists<bool>("use_quantizer"));
}

inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}

inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}

enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

template <typename T>
bool constexpr is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

}  // namespace platform
}  // namespace paddle