/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {

using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}

template <class Type>
using tf_desc = typename Type::desc;

template <class Type>
using tf_pd = typename Type::primitive_desc;

template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
  auto pd = new tf_pd<Type>(desc, e);
  return std::shared_ptr<tf_pd<Type>>(pd);
}

template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(args...);
  return tf_pd<Type>(desc, e, p);
}

inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(dims.begin(), dims.end() - 1,
                std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  // In these data layouts the channel dimension is either in the 2nd position
  // (nChw) or in the last position (nhwC), so for dims == 2 these layouts are
  // the same and nothing needs to be done. The same holds for dims == 1, where
  // there is just one possible combination.
  if (tensor_in->dims().size() < 3) {
    VLOG(3) << "Keeping kMKLDNN/kNHWC/kNDHWC output_shape"
            << print_dims(phi::vectorize<int>(tensor_in->dims()));
    return;
  }

  switch (from) {
    case framework::DataLayout::kMKLDNN:
      if ((to == framework::DataLayout::kNHWC) ||
          (to == framework::DataLayout::kNDHWC)) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC/kNDHWC output_shape"
                << print_dims(dims);
      }
      break;
    case framework::DataLayout::kNHWC:
    case framework::DataLayout::kNDHWC:
      if (to == framework::DataLayout::kMKLDNN) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kNHWC/kNDHWC to: kMKLDNN output_shape"
                << print_dims(dims);
      }
      break;
    default:
      break;
  }
}
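
// Illustrative sketch (hypothetical shapes, not taken from any caller): a
// tensor with dims [2, 3, 4, 5] held in kMKLDNN (NCHW) order becomes
// [2, 4, 5, 3] after
//   MatchShapeToLayout(&t, framework::DataLayout::kMKLDNN,
//                      framework::DataLayout::kNHWC);
// and the opposite direction rotates [2, 4, 5, 3] back to [2, 3, 4, 5].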

struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}

inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  // Clear the mkl-dnn cache
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->ResetBlobMap(ptr);
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
  // Block the next clearing of the mkl-dnn cache
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->BlockNextCacheClearing();
  }
}

template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}

inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  auto reorder_prim = dnnl::reorder(src, dst);
  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent record_reorder("int_reorder",
                                       platform::TracerEventType::UserDefined,
                                       2, platform::EventRole::kUniqueOp);
  reorder_prim.execute(astream, src, dst);
  astream.wait();
}
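
// Illustrative sketch (hypothetical buffers, engine and shapes): reordering an
// NCHW f32 tensor into a blocked layout could look like
//   auto src_md = MKLDNNMemDesc({8, 3, 224, 224}, MKLDNNGetDataType<float>(),
//                               MKLDNNMemoryFormat::nchw);
//   auto dst_md = MKLDNNMemDesc({8, 3, 224, 224}, MKLDNNGetDataType<float>(),
//                               MKLDNNMemoryFormat::nChw16c);
//   dnnl::memory src(src_md, engine, src_ptr);
//   dnnl::memory dst(dst_md, engine, dst_ptr);
//   Reorder(src, dst, engine);  // executes on the thread-local stream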

inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else if (strides[3] >= strides[2] && strides[2] >= strides[0] &&
                 strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::dcab;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::abcde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::acbde;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
                 strides[3] >= strides[4] && strides[4] >= strides[1]) {
        return dnnl::memory::format_tag::acdeb;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde4b;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      } else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
                 strides[1] >= strides[3] && strides[3] >= strides[4] &&
                 strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::acbdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IS IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}

inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  auto mem_desc = memory.get_desc();
  return GetMKLDNNFormat(mem_desc);
}

inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return dnnl::memory::format_tag::abc;
    case 4:
      return dnnl::memory::format_tag::abcd;
    case 5:
      return dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle supports tensors with rank in the range <1, 9>, but received "
          "a tensor with rank: %d",
          tensor_rank));
  }
}

inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return MKLDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return MKLDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == MKLDNNMemoryFormat::goihw) {
      return MKLDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == MKLDNNMemoryFormat::goidhw) {
      return MKLDNNMemoryFormat::oidhw;
    }
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncdhw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::abcdef;
    }
  }
  return data_format;
}
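
// Illustrative sketch (hypothetical inputs): MKLDNNFormatForSize(3,
// MKLDNNMemoryFormat::nchw) yields MKLDNNMemoryFormat::ncw, while
// MKLDNNFormatForSize(5, MKLDNNMemoryFormat::nhwc) yields
// MKLDNNMemoryFormat::ndhwc; unmatched combinations fall through and the
// original data_format is returned unchanged.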

inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  switch (framework::StringToDataLayout(data_format)) {
    case framework::DataLayout::kNHWC:
      return MKLDNNMemoryFormat::nhwc;
    case framework::DataLayout::kNCHW:
      return MKLDNNMemoryFormat::nchw;
    default:
      return MKLDNNMemoryFormat::any;
  }
}

inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (!format->compare("nchw")) {
    return MKLDNNMemoryFormat::nchw;
  } else if (!format->compare("nchw16c")) {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (!format->compare("nhwc")) {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}

inline std::string ThreadIDasStr(void) {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}

template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  key->append(std::to_string(num));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::format_tag& format) {
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::data_type& data_type) {
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::normalization_flags& flags) {
  key->append(std::to_string(static_cast<int>(flags)));
}

inline void AppendKey(std::string* key, const std::string& str) {
  key->append(str);
}

inline void AppendKey(std::string* key, const char* str) { key->append(str); }

template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (size_t i = 0; i < dims.size(); i++) {
    AppendKey(key, std::to_string(dims[i]));
  }
}

// If MKLDNN build and CPU place then register suffix in DeviceContext
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // Static vars will remember first executor and its thread
    // so both of them need to be processed by the same thread within
    // critical section
    static std::mutex static_vars_barrier;
    static_vars_barrier.lock();
    static auto first_exec = ptr;
    static auto first_thread = ThreadIDasStr();
    static_vars_barrier.unlock();

    if (first_exec != ptr) {
      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
    }
    // Let's register address of current executor
    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

    // For first thread
    if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
    }
  }
}

template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
  std::string key;
  key.reserve(64);
  using expand_type = int[];
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
  return key;
}
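
// Illustrative sketch (hypothetical arguments): a cache key for a conv-like
// primitive could be built as
//   auto key = CreateKey(dev_ctx, "conv2d", src_tz, MKLDNNGetDataType<float>());
//   key = ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
// Each argument is appended through the matching AppendKey overload, and the
// thread id is appended only when tid-in-key is enabled for this context.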

inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  return (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
          true)
             ? key + "-t:" + ThreadIDasStr()
             : key;
}

inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    int padding_front = paddings[0];
    int padding_back = paddings[1];
    int padding_top = paddings[2];
    int padding_bottom = paddings[3];
    int padding_left = paddings[4];
    int padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  } else {
    int padding_top = paddings[0];
    int padding_bottom = paddings[1];
    int padding_left = paddings[2];
    int padding_right = paddings[3];

    return {{padding_top, padding_left}, {padding_bottom, padding_right}};
  }
}
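
// Illustrative sketch (hypothetical values): for a 2-D convolution with
// paddings {1, 2, 3, 4} this returns {{1, 3}, {2, 4}}, i.e. {top, left} and
// {bottom, right}; a 6-element vector is split into the analogous 3-D pairs
// {front, top, left} and {back, bottom, right}.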

// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}
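
// Illustrative sketch (hypothetical dims): with weights_tz = {32, 16, 3, 3}
// and groups = 4 the vector becomes {4, 8, 16, 3, 3}, i.e. [o, i, h, w] is
// reshaped to [g, o/g, i, h, w]; for groups == 1 the dims are left untouched.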

inline void RegisterModelLayout(
    std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
    const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // If NHWC has already been registered, quit this call so that analyzing
    // an internal "while" op block does not overwrite the setting
    if (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() ==
        framework::DataLayout::kNHWC)
      return;

    VLOG(4) << "RegisterModelLayout for mkldnn";
    auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
                           const std::string& attrib_name) -> bool {
      if (op->HasAttr(attrib_name)) {
        auto data_format = op->Attr<std::string>(attrib_name);
        platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
            data_format.compare("NHWC") == 0 ? framework::DataLayout::kNHWC
                                             : framework::DataLayout::kNCHW);
        return true;
      } else {
        return false;
      }
    };

    for (auto& op : ops) {
      if (check_attrib(op, std::string("data_format"))) {
        return;
      }
      if (check_attrib(op, std::string("data_layout"))) {
        return;
      }
    }
  }
}

inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
          op->GetAttrIfExists<bool>("use_quantizer"));
}

inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}

inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}

enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

template <typename T>
bool constexpr is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

}  // namespace platform
}  // namespace paddle