// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "dnnl.hpp"  // NOLINT
#include "glog/logging.h"

#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

using OneDNNMemoryFormat = dnnl::memory::format_tag;
using OneDNNDataType = dnnl::memory::data_type;

template <typename Type>
void* to_void_cast(const Type* t) {
  return static_cast<void*>(const_cast<Type*>(t));
}
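
// Illustrative usage of to_void_cast (a sketch, not part of the original
// header): oneDNN memory objects take a mutable void* handle, so a const
// input buffer has its constness cast away before being handed over.
//
//   const float* src_data = input.data<float>();
//   dnnl::memory src_mem(src_md, engine, to_void_cast<float>(src_data));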

inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size,
                                              OneDNNMemoryFormat data_format) {
  if (dims_size == 1) {
    return OneDNNMemoryFormat::x;
  } else if (dims_size == 2) {
    return OneDNNMemoryFormat::nc;
  } else if (dims_size == 3) {
    if (data_format == OneDNNMemoryFormat::nchw) {
      return OneDNNMemoryFormat::ncw;
    } else if (data_format == OneDNNMemoryFormat::nhwc) {
      return OneDNNMemoryFormat::nwc;
    }
  } else if (dims_size == 4) {
    if (data_format == OneDNNMemoryFormat::goihw) {
      return OneDNNMemoryFormat::oihw;
    }
  } else if (dims_size == 5) {
    if (data_format == OneDNNMemoryFormat::goidhw) {
      return OneDNNMemoryFormat::oidhw;
    }
    if (data_format == OneDNNMemoryFormat::nchw) {
      return OneDNNMemoryFormat::ncdhw;
    } else if (data_format == OneDNNMemoryFormat::nhwc) {
      return OneDNNMemoryFormat::ndhwc;
    }
  } else if (dims_size == 6) {
    if (data_format == OneDNNMemoryFormat::nchw) {
      return OneDNNMemoryFormat::abcdef;
    }
  }
  return data_format;
}
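
// Illustrative usage (a sketch): collapsing an activation format tag to
// match a lower tensor rank, e.g. when a helper receives squeezed dims.
//
//   auto fmt = OneDNNFormatForSize(3, OneDNNMemoryFormat::nchw);
//   // fmt == OneDNNMemoryFormat::ncw
//   auto same = OneDNNFormatForSize(4, OneDNNMemoryFormat::nchw);
//   // a 4-D nchw format is returned unchanged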

inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) {
  switch (tensor_rank) {
    case 1:
      return dnnl::memory::format_tag::a;
    case 2:
      return dnnl::memory::format_tag::ab;
    case 3:
      return dnnl::memory::format_tag::abc;
    case 4:
      return dnnl::memory::format_tag::abcd;
    case 5:
      return dnnl::memory::format_tag::abcde;
    case 6:
      return dnnl::memory::format_tag::abcdef;
    case 7:
      return dnnl::memory::format_tag::abcdefg;
    case 8:
      return dnnl::memory::format_tag::abcdefgh;
    case 9:
      return dnnl::memory::format_tag::abcdefghi;
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Paddle supports tensors with rank in the range <1, 9>, but "
          "received a tensor with rank: %d",
          tensor_rank));
  }
}
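
// Illustrative usage (a sketch): picking the plain (row-major) format tag
// that matches a tensor's rank.
//
//   auto tag = GetPlainOneDNNFormat(4);
//   // tag == dnnl::memory::format_tag::abcd, i.e. plain NCHW-like order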

inline void MatchShapeToLayout(DenseTensor* tensor_in,
                               DataLayout from,
                               DataLayout to) {
  auto print_dims = [](const std::vector<int>& dims) {
    std::ostringstream oss;

    if (!dims.empty()) {
      oss << "[";
      // Convert all but the last element to avoid a trailing ","
      std::copy(
          dims.begin(), dims.end() - 1, std::ostream_iterator<int>(oss, ","));

      // Now add the last element with no delimiter
      oss << dims.back() << "]";
    }

    return oss.str();
  };

  // In these data layouts the channel dimension is either in the 2nd
  // position (nChw) or in the last position (nhwC), so for dims == 2 the
  // layouts coincide and nothing needs to be done. The same holds for
  // dims == 1, where only one layout is possible.
  if (tensor_in->dims().size() < 3) {
    VLOG(3) << "Keeping MKLDNN/NHWC/NDHWC output_shape: "
            << print_dims(phi::vectorize<int>(tensor_in->dims()));
    return;
  }

  switch (from) {
    case DataLayout::MKLDNN:
      if ((to == DataLayout::NHWC) || (to == DataLayout::NDHWC)) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating Shape from: MKLDNN to: NHWC/NDHWC output_shape"
                << print_dims(dims);
      }
      break;
    case DataLayout::NHWC:
    case DataLayout::NDHWC:
      if (to == DataLayout::MKLDNN) {
        auto dims = phi::vectorize<int>(tensor_in->dims());
        std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
        tensor_in->Resize(phi::make_ddim(dims));
        VLOG(3) << "Rotating shape from NHWC/NDHWC to MKLDNN, output_shape: "
                << print_dims(dims);
      }
      break;
    default:
      break;
  }
}
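
// Illustrative usage (a sketch): the rotation moves the channel dimension
// between the 2nd and the last position.
//
//   phi::DenseTensor t;
//   t.Resize(phi::make_ddim({8, 3, 32, 32}));  // oneDNN (NCHW) order
//   MatchShapeToLayout(&t, DataLayout::MKLDNN, DataLayout::NHWC);
//   // t.dims() is now [8, 32, 32, 3]
//   MatchShapeToLayout(&t, DataLayout::NHWC, DataLayout::MKLDNN);
//   // t.dims() is back to [8, 3, 32, 32]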

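// Placeholder primitive type for handlers that have no real oneDNN
// primitive; it only provides the nested types such handlers expect.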
struct onednn_dummy_primitive {
  struct primitive_desc {};
  struct desc {};
};

inline dnnl::memory::desc OneDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        OneDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}
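
// Illustrative usage (a sketch): building a memory descriptor for an fp32
// NCHW tensor.
//
//   auto md = OneDNNMemDesc({8, 3, 32, 32},
//                           dnnl::memory::data_type::f32,
//                           OneDNNMemoryFormat::nchw);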

}  // namespace funcs
}  // namespace phi