mkldnn_helper.h 19.8 KB
Newer Older
1
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

16
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "dnnl.hpp"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
T
tensor-tang 已提交
27
namespace paddle {
28
#ifdef PADDLE_WITH_MKLDNN
29
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
30
#endif
T
tensor-tang 已提交
31 32
namespace platform {

33 34 35 36 37 38
// Short project-local aliases for the dnnl types used throughout this header.
using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
// Generic handle wrapping the C-level primitive descriptor type.
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;
T
tensor-tang 已提交
39

40 41 42 43 44
typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
T
tensor-tang 已提交
45

46 47 48 49 50
// Strips constness from `t` and returns it as an untyped `void*`.
// The pointee is not touched; only the pointer type changes.
template <typename Type>
void* to_void_cast(const Type* t) {
  Type* writable = const_cast<Type*>(t);
  return static_cast<void*>(writable);
}

K
Krzysztof Binias 已提交
51 52 53 54 55
// Same contract as to_void_cast, but the final conversion to `void*` is done
// with reinterpret_cast instead of static_cast.
template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  Type* writable = const_cast<Type*>(t);
  return reinterpret_cast<void*>(writable);
}

56 57 58 59 60 61 62 63 64
// Shorthand for the nested `desc` type of the primitive type `Type`.
template <class Type>
using tf_desc = typename Type::desc;

// Shorthand for the nested `primitive_desc` type of the primitive type `Type`.
template <class Type>
using tf_pd = typename Type::primitive_desc;

// Builds a forward primitive descriptor for the primitive type `Type`.
//
// A `Type::desc` is constructed from `dnnl::prop_kind::forward` followed by
// `args`, and is then combined with engine `e` into a `Type::primitive_desc`.
//
// Returns a shared_ptr owning the new primitive descriptor. The object is
// created with std::make_shared, so no raw owning pointer is ever exposed.
//
// `args` are intentionally passed on as lvalues, exactly as callers have
// always observed.
template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, args...);
  return std::make_shared<tf_pd<Type>>(desc, e);
}

// Builds a backward primitive descriptor for the primitive type `Type`.
//
// A `Type::desc` is constructed from `args`; the returned
// `Type::primitive_desc` is then constructed from that descriptor, the engine
// `e` and `p` (passed as the third constructor argument).
template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
                                   Args&&... args) {
  auto backward_desc = tf_desc<Type>(args...);
  tf_pd<Type> backward_pd(backward_desc, e, p);
  return backward_pd;
}

77 78 79
// Rotates the shape stored in `tensor_in` so that it matches layout `to`.
//
// Only two conversions change anything:
//   kMKLDNN -> kNHWC : the 2nd dimension is moved to the last position
//   kNHWC -> kMKLDNN : the last dimension is moved to the 2nd position
// Every other (from, to) pair leaves the tensor untouched. Only the shape is
// resized here.
inline void MatchShapeToLayout(framework::Tensor* tensor_in,
                               framework::DataLayout from,
                               framework::DataLayout to) {
  // The channel dimension is either in 2nd position (nChw) or last (nhwC).
  // For rank 2 both layouts coincide and for rank 1 there is only one possible
  // arrangement, so there is nothing to do below rank 3.
  if (tensor_in->dims().size() < 3) {
    return;
  }

  const bool mkldnn_to_nhwc = (from == framework::DataLayout::kMKLDNN) &&
                              (to == framework::DataLayout::kNHWC);
  const bool nhwc_to_mkldnn = (from == framework::DataLayout::kNHWC) &&
                              (to == framework::DataLayout::kMKLDNN);
  if (!mkldnn_to_nhwc && !nhwc_to_mkldnn) {
    return;
  }

  // Renders a shape as "[d0,d1,...,dn]"; an empty shape gives "".
  auto shape_to_string = [](const std::vector<int>& shape) {
    std::ostringstream text;
    if (shape.empty()) {
      return text.str();
    }
    text << "[";
    for (size_t i = 0; i + 1 < shape.size(); ++i) {
      text << shape[i] << ",";
    }
    text << shape.back() << "]";
    return text.str();
  };

  auto shape = framework::vectorize<int>(tensor_in->dims());
  if (mkldnn_to_nhwc) {
    // [n, c, ...] -> [n, ..., c]
    std::rotate(shape.begin() + 1, shape.begin() + 2, shape.end());
  } else {
    // [n, ..., c] -> [n, c, ...]
    std::rotate(shape.begin() + 1, shape.end() - 1, shape.end());
  }
  tensor_in->Resize(framework::make_ddim(shape));

  if (mkldnn_to_nhwc) {
    VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
            << shape_to_string(shape);
  } else {
    VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
            << shape_to_string(shape);
  }
}

127 128 129 130 131
// Placeholder primitive type. It exposes empty nested `primitive_desc` and
// `desc` types, i.e. the two nested names that tf_pd<> and tf_desc<> look up.
// NOTE(review): presumably used as a stand-in template argument where no real
// dnnl primitive applies -- confirm against callers.
struct mkldnn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

132 133 134 135
// Creates a dnnl memory descriptor from the given dimensions, element data
// type and memory format tag.
inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  dnnl::memory::desc descriptor({dims}, data_type, format);
  return descriptor;
}

138 139
// Clears the mkl-dnn cache of the device context registered for `place`.
//
// Does nothing unless `place` is a CPU place. Otherwise the context's blob map
// is reset with `ptr`, and the thread-local current paddle data layout is set
// back to kNCHW.
//
// place: place whose device context should be cleaned.
// ptr:   forwarded to ResetBlobMap(); defaults to nullptr.
inline void ClearMKLDNNCache(const platform::Place& place,
                             void* ptr = nullptr) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // Named cast instead of a C-style cast: the downcast is explicit and is
  // rejected at compile time if the two types are unrelated.
  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
  dev_ctx->ResetBlobMap(ptr);
  platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
      paddle::framework::DataLayout::kNCHW);
}

151 152 153 154 155 156 157 158 159 160
// Asks the device context registered for `place` to block its next cache
// clearing (BlockNextCacheClearing()). Does nothing unless `place` is a CPU
// place.
inline void DontClearMKLDNNCache(const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  // Named cast instead of a C-style cast: the downcast is explicit and is
  // rejected at compile time if the two types are unrelated.
  auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
  dev_ctx->BlockNextCacheClearing();
}

161
template <typename Type>
162 163
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
164 165 166
}

template <>
167 168
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
169 170
}
template <>
171 172
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
173
}
P
Physher 已提交
174
template <>
175 176
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
P
Physher 已提交
177 178
}
template <>
179 180
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
P
Physher 已提交
181 182
}

183
template <>
184 185
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
186 187
}

188 189 190
// Executes a dnnl reorder from `src` into `dst` on the thread-local stream and
// waits for the stream before returning.
//
// The work done after the RecordEvent is created is recorded under the name
// "int_reorder". `engine` is accepted for interface compatibility only; it is
// not used by the body.
inline void Reorder(dnnl::memory src, dnnl::memory dst,
                    const dnnl::engine& engine) {
  dnnl::reorder reorder_primitive(src, dst);
  auto& stream = platform::MKLDNNDeviceContext::tls().get_stream();
  platform::RecordEvent reorder_event("int_reorder",
                                      platform::EventRole::kUniqueOp);
  reorder_primitive.execute(stream, src, dst);
  stream.wait();
}

198
// Deduces a dnnl format tag from the blocking information stored in
// `mem_desc` (rank, strides, number of inner blocks, inner block sizes and the
// axes those blocks apply to).
//
// Only the layouts listed below are recognised; anything else yields
// format_tag::undef. For plain (non-blocked) descriptors of rank 3, 4 and 5
// the last candidate of each rank is a catch-all for any stride order not
// matched before it.
inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
  using tag = dnnl::memory::format_tag;

  const auto& blocking = mem_desc.data.format_desc.blocking;
  const auto rank = mem_desc.data.ndims;
  const auto num_blocks = blocking.inner_nblks;
  const auto* strides = blocking.strides;
  const auto* block_sizes = blocking.inner_blks;
  const auto* block_axes = blocking.inner_idxs;

  // True when strides[a0] >= strides[a1] >= ... for the listed axes, i.e. the
  // axes are given from outermost to innermost.
  const auto descending = [strides](std::initializer_list<int> axes) {
    return std::adjacent_find(axes.begin(), axes.end(),
                              [strides](int outer, int inner) {
                                return strides[outer] < strides[inner];
                              }) == axes.end();
  };

  switch (rank) {
    case 1:
      return tag::x;

    case 2:
      if (num_blocks == 0) {
        return descending({0, 1}) ? tag::nc : tag::cn;
      }
      break;

    case 3:
      if (num_blocks == 0) {
        if (descending({0, 1, 2})) return tag::ncw;
        if (descending({1, 0, 2})) return tag::ntc;
        return tag::nwc;
      }
      break;

    case 4:
      if (num_blocks == 0) {
        if (descending({0, 1, 2, 3})) return tag::nchw;
        if (descending({2, 3, 1, 0})) return tag::cdba;
        if (descending({3, 2, 0, 1})) return tag::dcab;
        return tag::nhwc;
      }
      if (num_blocks == 1) {
        const auto size = block_sizes[0];
        const auto axis = block_axes[0];
        if (axis == 1) {
          if (size == 16) return tag::nChw16c;
          if (size == 8) return tag::nChw8c;
          if (size == 4) return tag::nChw4c;
        } else if (axis == 0 && descending({0, 2, 3, 1})) {
          if (size == 8) return tag::Acdb8a;
          if (size == 16) return tag::Acdb16a;
        }
      } else if (num_blocks == 2) {
        if (block_axes[0] == 1 && block_axes[1] == 0) {
          if (block_sizes[0] == 16 && block_sizes[1] == 16) {
            return tag::OIhw16i16o;
          }
          if (block_sizes[0] == 8 && block_sizes[1] == 8) {
            return tag::OIhw8i8o;
          }
        }
      }
      break;

    case 5:
      if (num_blocks == 0) {
        return descending({0, 1, 2, 3, 4}) ? tag::ncdhw : tag::ndhwc;
      }
      if (num_blocks == 1) {
        const auto size = block_sizes[0];
        const auto axis = block_axes[0];
        if (axis == 0) {
          if (descending({0, 2, 3, 4, 1})) {
            if (size == 8) return tag::Acdeb8a;
            if (size == 16) return tag::Acdeb16a;
          }
          if (descending({0, 1, 2, 3, 4})) {
            if (size == 8) return tag::Abcde8a;
            if (size == 16) return tag::Abcde16a;
          }
        } else if (axis == 1 && descending({0, 1, 2, 3, 4})) {
          if (size == 8) return tag::aBcde8b;
          if (size == 16) return tag::aBcde16b;
        }
      }
      break;

    case 6:
      if (num_blocks == 0 && descending({0, 1, 2, 3, 4, 5})) {
        return tag::abcdef;
      }
      break;

    default:
      break;
  }

  // DEBUG CODE - KEEP UNTIL TENSOR.MEMORY_DESC IMPLEMENTED
  // std::cout<<"@@@@@@@@@@ UNDEFINED FORMAT @@@@@@@@@@@@@@@@@@@"<<std::endl;
  // std::cout<<"NDIMS: "<<ndims<<std::endl;
  // std::cout<<"INNER_NBLKS: "<<inner_nblks<<std::endl;
  // for (int i=0;i<ndims;++i) {
  //   std::cout<<"STRIDE["<<i<<"]: "<<strides[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_BLKS["<<i<<"]: "<<inner_blks[i]<<std::endl;
  // }
  // for (int i=0;i<inner_nblks;++i) {
  //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
  // }
  return tag::undef;
}

332
// Convenience overload: deduces the format tag from the descriptor of `memory`.
inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
  return GetMKLDNNFormat(memory.get_desc());
}

337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
inline mkldnn::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return mkldnn::memory::format_tag::a;
      break;
    case 2:
      return mkldnn::memory::format_tag::ab;
      break;
    case 3:
      return mkldnn::memory::format_tag::abc;
      break;
    case 4:
      return mkldnn::memory::format_tag::abcd;
      break;
    case 5:
      return mkldnn::memory::format_tag::abcde;
      break;
    case 6:
      return mkldnn::memory::format_tag::abcdef;
      break;
    case 7:
      return mkldnn::memory::format_tag::abcdefg;
      break;
    case 8:
      return mkldnn::memory::format_tag::abcdefgh;
      break;
    case 9:
      return mkldnn::memory::format_tag::abcdefghi;
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Paddle support tensors with rank in range <1, 9>, but received "
          "tensor with rank: %d",
          tensor_rank));
  }
}

374 375
// Picks the memory format matching a tensor of rank `dims_size`, using
// `data_format` as a hint.
//
//   rank 1 -> x
//   rank 2 -> nc
//   rank 3 -> nchw => ncw, nhwc => nwc
//   rank 4 -> goihw => oihw
//   rank 5 -> goidhw => oidhw, nchw => ncdhw, nhwc => ndhwc
//   rank 6 -> abcdef
// Any combination not listed returns `data_format` unchanged.
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  switch (dims_size) {
    case 1:
      return MKLDNNMemoryFormat::x;
    case 2:
      return MKLDNNMemoryFormat::nc;
    case 3:
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::nwc;
      }
      break;
    case 4:
      if (data_format == MKLDNNMemoryFormat::goihw) {
        return MKLDNNMemoryFormat::oihw;
      }
      break;
    case 5:
      if (data_format == MKLDNNMemoryFormat::goidhw) {
        return MKLDNNMemoryFormat::oidhw;
      }
      if (data_format == MKLDNNMemoryFormat::nchw) {
        return MKLDNNMemoryFormat::ncdhw;
      }
      if (data_format == MKLDNNMemoryFormat::nhwc) {
        return MKLDNNMemoryFormat::ndhwc;
      }
      break;
    case 6:
      return MKLDNNMemoryFormat::abcdef;
    default:
      break;
  }
  return data_format;
}

405
// Converts a data-layout name into a memory format: the string is parsed with
// framework::StringToDataLayout; kNHWC maps to nhwc, kNCHW maps to nchw and
// every other layout maps to any.
inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  const auto layout = framework::StringToDataLayout(data_format);
  if (layout == framework::DataLayout::kNHWC) {
    return MKLDNNMemoryFormat::nhwc;
  }
  if (layout == framework::DataLayout::kNCHW) {
    return MKLDNNMemoryFormat::nchw;
  }
  return MKLDNNMemoryFormat::any;
}

417
// Parses a format name (case-insensitively) into a memory format.
//
// format: in/out; the pointed-to string is lower-cased IN PLACE, as before.
// Returns nchw, nChw16c, nChw8c or nhwc for the matching names, and `any` for
// everything else.
inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  // tolower() has undefined behaviour for negative values other than EOF, so
  // each character is routed through unsigned char first.
  std::transform(
      format->begin(), format->end(), format->begin(),
      [](unsigned char c) { return static_cast<char>(::tolower(c)); });

  if (*format == "nchw") {
    return MKLDNNMemoryFormat::nchw;
  }
  if (*format == "nchw16c") {
    return MKLDNNMemoryFormat::nChw16c;
  }
  if (*format == "nchw8c") {
    return MKLDNNMemoryFormat::nChw8c;
  }
  if (*format == "nhwc") {
    return MKLDNNMemoryFormat::nhwc;
  }
  return MKLDNNMemoryFormat::any;
}

A
Adam 已提交
433 434 435 436 437
// Returns the hash of the calling thread's id, rendered as a decimal string.
inline std::string ThreadIDasStr() {
  const std::thread::id current = std::this_thread::get_id();
  const std::size_t hashed = std::hash<std::thread::id>()(current);
  return std::to_string(hashed);
}

438 439 440
// Generic case: appends the std::to_string() rendering of `num` to `*key`.
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
  *key += std::to_string(num);
}

A
Adam 已提交
443 444
template <>
inline void AppendKey(std::string* key,
445
                      const dnnl::memory::format_tag& format) {
A
Adam 已提交
446 447 448 449 450
  key->append(std::to_string(static_cast<int>(format)));
}

template <>
inline void AppendKey(std::string* key,
451
                      const dnnl::memory::data_type& data_type) {
A
Adam 已提交
452 453 454 455
  key->append(std::to_string(static_cast<int>(data_type)));
}

template <>
456
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
A
Adam 已提交
457 458 459 460 461
  key->append(std::to_string(static_cast<int>(algorithm)));
}

template <>
inline void AppendKey(std::string* key,
462
                      const dnnl::normalization_flags& flags) {
A
Adam 已提交
463 464 465
  key->append(std::to_string(static_cast<int>(flags)));
}

466 467
// Strings are appended verbatim.
inline void AppendKey(std::string* key, const std::string& str) {
  *key += str;
}

470
// C strings are appended verbatim.
inline void AppendKey(std::string* key, const char* str) {
  *key += str;
}
A
Adam 已提交
471

A
Adam 已提交
472 473
// Vectors: every element is appended in order as its std::to_string()
// rendering, with no separator between elements.
template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
  for (const auto& dim : dims) {
    key->append(std::to_string(dim));
  }
}

479 480 481 482
// If MKLDNN build and CPU place then register suffix in DeviceContext.
//
// For a CPU `place` this:
//   * remembers, on the very first call, the executor pointer and the calling
//     thread (function-local statics);
//   * sets the thread-local key suffix to "E<address>" when `ptr` differs from
//     that first executor;
//   * records `ptr` as the current executor;
//   * disables the thread id in keys when called from the first thread.
// For any other place it does nothing.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                         const platform::Place& place) {
  if (!platform::is_cpu_place(place)) {
    return;
  }

  // Static vars will remember first executor and its thread, so both of them
  // need to be processed by the same thread within a critical section.
  // The lock is RAII-managed: if initialising the statics throws, the mutex is
  // released during unwinding instead of staying locked forever.
  static std::mutex static_vars_barrier;
  std::unique_lock<std::mutex> barrier_guard(static_vars_barrier);
  static auto first_exec = ptr;
  static auto first_thread = ThreadIDasStr();
  barrier_guard.unlock();

  if (first_exec != ptr) {
    paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
        "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
  }
  // Let's register address of current executor
  paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);

  // For first thread
  if (first_thread == ThreadIDasStr()) {
    paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
  }
}

506
template <typename... ArgTypes>
507 508
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
                             ArgTypes&&... args) {
509
  std::string key;
510
  key.reserve(64);
511
  using expand_type = int[];
512
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
J
Jacek Czaja 已提交
513
  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
514 515 516
  return key;
}

517 518
// Returns `key` extended with "-t:<thread id hash>" when the thread-local
// settings say the thread id is used in keys; otherwise returns `key` as is.
// `dev_ctx` is not used by the body.
inline std::string ExtendKeyWithThreadInfoIfNeeded(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  if (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key()) {
    return key + "-t:" + ThreadIDasStr();
  }
  return key;
}

A
Adam 已提交
525 526
// Splits a flat padding list into {leading paddings, trailing paddings}.
//
// paddings: either 6 values laid out as
//             [front, back, top, bottom, left, right]
//           or (any other size) at least 4 values laid out as
//             [top, bottom, left, right].
// Returns {{front, top, left}, {back, bottom, right}} for the 6-value form and
//         {{top, left}, {bottom, right}} otherwise.
//
// Values are kept as int64_t end to end; they are no longer squeezed through
// `int`, which silently truncated large values.
// Precondition: paddings.size() is 6 or at least 4 (not checked here).
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    const int64_t padding_front = paddings[0];
    const int64_t padding_back = paddings[1];
    const int64_t padding_top = paddings[2];
    const int64_t padding_bottom = paddings[3];
    const int64_t padding_left = paddings[4];
    const int64_t padding_right = paddings[5];

    return {{padding_front, padding_top, padding_left},
            {padding_back, padding_bottom, padding_right}};
  }

  const int64_t padding_top = paddings[0];
  const int64_t padding_bottom = paddings[1];
  const int64_t padding_left = paddings[2];
  const int64_t padding_right = paddings[3];

  return {{padding_top, padding_left}, {padding_bottom, padding_right}};
}

547 548 549 550 551 552 553 554 555 556 557 558 559
// Adjusts the vector of weight dimensions for group convolutions.
//
// When `groups` > 1 the group count becomes a new leading dimension and the
// former leading dimension (output channels) is divided by it:
//   conv3d: [o, i, d, h, w] -> [g, o/g, i, d, h, w]
//   conv2d: [o, i, h, w]    -> [g, o/g, i, h, w]
// For `groups` <= 1 the vector is left untouched.
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups <= 1) {
    return;
  }
  weights_tz.insert(weights_tz.begin(), groups);
  weights_tz[1] /= groups;
}

560 561 562 563 564
// True when the op's "mkldnn_data_type" attribute equals "int8", or when its
// "use_quantizer" attribute is true. The second attribute is only read when
// the first check fails.
inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
  if (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8") {
    return true;
  }
  return op->GetAttrIfExists<bool>("use_quantizer");
}

565 566 567 568 569 570 571
// True when the op's "mkldnn_data_type" attribute equals "bfloat16".
inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
  const auto data_type = op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return data_type == "bfloat16";
}

// True when the op's "mkldnn_data_type" attribute equals "float32".
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  const auto data_type = op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return data_type == "float32";
}
A
Adam Osewski 已提交
572

A
Adam 已提交
573 574
// Selects the direction/kind of a reorder used by RNN code.
// NOTE(review): presumably "PP" stands for Paddle's own layout and NTC / TNC
// for batch-major / time-major layouts, each enumerator reading as
// <source>_<destination> -- confirm against the users of this enum.
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };

A
Adam Osewski 已提交
575 576 577 578 579
// Compile-time predicate: true exactly when T is int8_t or uint8_t.
template <typename T>
constexpr bool is_int8() {
  return std::is_same<T, uint8_t>::value ? true
                                         : std::is_same<T, int8_t>::value;
}

T
tensor-tang 已提交
580 581
}  // namespace platform
}  // namespace paddle