/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {

using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}

template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  return reinterpret_cast<void*>(const_cast<Type*>(t));
}
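// Illustrative usage (a sketch; the tensor variable is an assumption): dnnl
// memories take a mutable void* handle, so constness must be cast away when
// wrapping read-only Paddle tensor data:
//   dnnl::memory mem(md, engine, to_void_cast<float>(tensor.data<float>()));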

template <class Type>
using tf_desc = typename Type::desc;

template <class Type>
using tf_pd = typename Type::primitive_desc;

template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc =
      tf_desc<Type>(dnnl::prop_kind::forward, std::forward<Args>(args)...);
  return std::make_shared<tf_pd<Type>>(desc, e);
}

template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(std::forward<Args>(args)...);
  return tf_pd<Type>(desc, e, p);
}
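// Illustrative usage (a sketch; the softmax primitive and its arguments are
// assumptions, not prescribed by this header): the helper prepends
// prop_kind::forward, builds Type::desc from the remaining args, and creates
// the primitive_desc on the given engine:
//   auto fwd_pd = MKLDNNFwdPrimitiveDesc<dnnl::softmax_forward>(
//       engine, src_md, 1 /* softmax axis */);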

inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // In these data layouts, the channel dimension is either in the 2nd
  // position (nChw) or the last one (nhwC). For dims == 2 the layouts
  // coincide, and for dims == 1 only one layout is possible, so in both
  // cases nothing needs to be done.
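  // Example (illustrative): dims [2, 3, 4, 5] in kMKLDNN (NCHW) order are
  // rotated to [2, 4, 5, 3] for kNHWC, and rotated back the other way.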
  if (tensor_in->dims().size() < 3) {
    return;
  }

  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(dims.begin(), dims.end() - 1,
                std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  switch (from) {
    case framework::DataLayout::kMKLDNN:
      if (to == framework::DataLayout::kNHWC) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
                << print_dims(dims);
      }
      break;
    case framework::DataLayout::kNHWC:
      if (to == framework::DataLayout::kMKLDNN) {
        auto dims = framework::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(framework::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
                << print_dims(dims);
      }
      break;
    default:
      break;
  }
}

struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}
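// Illustrative usage (the shape is an assumption): an f32 NCHW descriptor
// for a batch of 8 3-channel 224x224 images:
//   auto md = MKLDNNMemDesc({8, 3, 224, 224}, dnnl::memory::data_type::f32,
//                           MKLDNNMemoryFormat::nchw);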

inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  // Clear the MKL-DNN cache.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* dev_ctx =
        static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
    dev_ctx->ResetBlobMap(ptr);
    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
        paddle::framework::DataLayout::kNCHW);
  }
}

inline void DontClearMKLDNNCache(const platform::Place& place) {
  // Block the next clearing of the MKL-DNN cache.
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* dev_ctx =
        static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
    dev_ctx->BlockNextCacheClearing();
  }
}

template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}

template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}

inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  auto reorder_prim = dnnl::reorder(src, dst);
  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent record_reorder("int_reorder",
                                       platform::EventRole::kUniqueOp);
  reorder_prim.execute(astream, src, dst);
  astream.wait();
}
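// Illustrative usage of Reorder (src_md, dst_md and the data pointers are
// assumptions): copies/converts data between two differently laid out
// memories on the thread-local stream:
//   dnnl::memory src(src_md, engine, src_ptr);
//   dnnl::memory dst(dst_md, engine, dst_ptr);
//   Reorder(src, dst, engine);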

inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  auto ndims = mem_desc.data.ndims;
  auto strides = mem_desc.data.format_desc.blocking.strides;
  auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
  auto inner_blks = mem_desc.data.format_desc.blocking.inner_blks;
  auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

  if (ndims == 1) {
    return dnnl::memory::format_tag::x;
  } else if (ndims == 2) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1]) {
        return dnnl::memory::format_tag::nc;
      } else {
        return dnnl::memory::format_tag::cn;
      }
    }
  } else if (ndims == 3) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
        return dnnl::memory::format_tag::ntc;
      } else {
        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
        return dnnl::memory::format_tag::cdba;
      } else {
        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
  } else if (ndims == 5) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4]) {
        return dnnl::memory::format_tag::ncdhw;
      } else {
        return dnnl::memory::format_tag::ndhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
  } else if (ndims == 6) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3] && strides[3] >= strides[4] &&
          strides[4] >= strides[5]) {
        return dnnl::memory::format_tag::abcdef;
      }
    }
  }
  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IS IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return dnnl::memory::format_tag::undef;
}

inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory& memory) {
  auto mem_desc = memory.get_desc();
  return GetMKLDNNFormat(mem_desc);
}
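// Illustrative examples for GetMKLDNNFormat: a 4-D desc with monotonically
// decreasing strides and no inner blocking maps to nchw; a single inner block
// of 16 on dim 1 maps to nChw16c; layouts not matched above yield
// format_tag::undef.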

inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return dnnl::memory::format_tag::abc;
    case 4:
      return dnnl::memory::format_tag::abcd;
    case 5:
      return dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle supports tensors with rank in range [1, 9], but received a "
          "tensor with rank: %d",
          tensor_rank));
  }
}
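// e.g. GetPlainMKLDNNFormat(4) returns dnnl::memory::format_tag::abcd, the
// plain stride-ordered layout for a rank-4 tensor.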

inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return MKLDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return MKLDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == MKLDNNMemoryFormat::goihw) {
      return MKLDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == MKLDNNMemoryFormat::goidhw) {
      return MKLDNNMemoryFormat::oidhw;
    }
    if (data_format == MKLDNNMemoryFormat::nchw) {
      return MKLDNNMemoryFormat::ncdhw;
    } else if (data_format == MKLDNNMemoryFormat::nhwc) {
      return MKLDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    return MKLDNNMemoryFormat::abcdef;
  }
  return data_format;
}
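// Example (illustrative): MKLDNNFormatForSize(5, MKLDNNMemoryFormat::nchw)
// returns MKLDNNMemoryFormat::ncdhw; an unmatched combination falls through
// and returns data_format unchanged.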

inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  switch (framework::StringToDataLayout(data_format)) {
    case framework::DataLayout::kNHWC:
      return MKLDNNMemoryFormat::nhwc;
    case framework::DataLayout::kNCHW:
      return MKLDNNMemoryFormat::nchw;
    default:
      return MKLDNNMemoryFormat::any;
  }
}

inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  std::transform(format->begin(), format->end(), format->begin(), ::tolower);

  if (*format == "nchw") {
    return MKLDNNMemoryFormat::nchw;
  } else if (*format == "nchw16c") {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (*format == "nchw8c") {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (*format == "nhwc") {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}

inline std::string ThreadIDasStr() {
  return std::to_string(
      std::hash<std::thread::id>()(std::this_thread::get_id()));
}

template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  key->append(std::to_string(num));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::format_tag& format) {
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::memory::data_type& data_type) {
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
                      const dnnl::normalization_flags& flags) {
  key->append(std::to_string(static_cast<int>(flags)));
}

inline void AppendKey(std::string* key, const std::string& str) {
  key->append(str);
}

inline void AppendKey(std::string* key, const char* str) { key->append(str); }

template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (size_t i = 0; i < dims.size(); i++) {
    AppendKey(key, std::to_string(dims[i]));
  }
}

// If built with MKL-DNN and running on a CPU place, register an
// executor-specific suffix for MKL-DNN cache keys in the DeviceContext.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    // Static vars will remember first executor and its thread
    // so both of them need to be processed by the same thread within
    // critical section
    static std::mutex static_vars_barrier;
    static_vars_barrier.lock();
    static auto first_exec = ptr;
    static auto first_thread = ThreadIDasStr();
    static_vars_barrier.unlock();

    if (first_exec != ptr) {
      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
    }
    // Let's register the address of the current executor
    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

    // For the first thread
    if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
    }
  }
}

template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
  std::string key;
  key.reserve(64);
  using expand_type = int[];
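  // Brace-initializing a dummy int array expands the parameter pack and
  // guarantees left-to-right AppendKey evaluation for every argument.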
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
  return key;
}
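// Illustrative usage (argument values are assumptions): each argument is
// serialized via an AppendKey overload and concatenated, e.g.
//   auto key = CreateKey(dev_ctx, "conv2d", src_tz,
//                        MKLDNNGetDataType<float>());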

inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  return paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key()
             ? key + "-t:" + ThreadIDasStr()
             : key;
}
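// Illustrative: with thread info enabled a key such as "conv2d_k" becomes
// "conv2d_k-t:<hashed thread id>"; otherwise it is returned unchanged.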

inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    int padding_front = paddings[0];
    int padding_back = paddings[1];
    int padding_top = paddings[2];
    int padding_bottom = paddings[3];
    int padding_left = paddings[4];
    int padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  } else {
    int padding_top = paddings[0];
    int padding_bottom = paddings[1];
    int padding_left = paddings[2];
    int padding_right = paddings[3];

    return {{padding_top, padding_left}, {padding_bottom, padding_right}};
  }
}
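// Example (illustrative): {1, 2, 3, 4, 5, 6} given as (front, back, top,
// bottom, left, right) becomes {{1, 3, 5}, {2, 4, 6}}; a 4-element input
// {1, 2, 3, 4} becomes {{1, 3}, {2, 4}}. A 4- or 6-element input is assumed.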

// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups > 1) {
    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
    // else [o, i, h, w] -> [g, o/g, i, h, w]
    weights_tz.push_back(0);
    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
    weights_tz[0] = groups;
    weights_tz[1] = weights_tz[1] / groups;
  }
}
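// Example (illustrative): with groups == 2, weights dims [32, 16, 3, 3]
// ([o, i, h, w]) become [2, 16, 16, 3, 3] ([g, o/g, i, h, w]).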

inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
          op->GetAttrIfExists<bool>("use_quantizer"));
}

inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}

inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
}

enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

template <typename T>
bool constexpr is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

}  // namespace platform
}  // namespace paddle