//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_factory.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/phi/core/enforce.h"

namespace phi {
21 22 23 24 25 26 27 28 29 30

uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
  uint32_t hash_value = 0;
  // |----31-20------|---19-12---|---11-8----|---7-0---|
  // | For extension | DataType | DataLayout | Backend |
  hash_value |= static_cast<uint8_t>(key.backend());
  hash_value |=
      (static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
  hash_value |=
      (static_cast<uint16_t>(key.dtype())
31
       << (KernelKey::kBackendBitLength + KernelKey::kDataLayoutBitLength));
32 33 34 35 36 37 38 39
  return hash_value;
}

// Returns the single process-wide KernelFactory. The function-local static is
// constructed on first call; C++11 guarantees that initialization is
// thread-safe.
KernelFactory& KernelFactory::Instance() {
  static KernelFactory factory_singleton;
  return factory_singleton;
}

Y
YuanRisheng 已提交
40
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
41 42 43 44 45 46 47 48 49 50 51 52
                                   const KernelKey& kernel_key) const {
  auto iter = kernels_.find(kernel_name);
  if (iter == kernels_.end()) {
    return Kernel();
  }
  auto kernel_iter = iter->second.find(kernel_key);
  if (kernel_iter == iter->second.end()) {
    return Kernel();
  }
  return kernel_iter->second;
}

53 54
KernelKeyMap KernelFactory::SelectKernelMap(
    const std::string& kernel_name) const {
55 56
  auto iter = kernels_.find(kernel_name);
  if (iter == kernels_.end()) {
57
    return KernelKeyMap();
58 59 60 61
  }
  return iter->second;
}

62
const Kernel& KernelFactory::SelectKernelOrThrowError(
Y
YuanRisheng 已提交
63
    const std::string& kernel_name, const KernelKey& kernel_key) const {
64
  auto iter = kernels_.find(kernel_name);
65 66 67 68
  PADDLE_ENFORCE_NE(
      iter,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
69 70 71

  auto kernel_iter = iter->second.find(kernel_key);
  // TODO(chenweihang): polish refind impl here
72
  if (kernel_iter == iter->second.end() &&
73 74 75
      kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
    phi::KernelKey any_layout_kernel_key(
        kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
76 77 78 79 80
    kernel_iter = iter->second.find(any_layout_kernel_key);
  }
  PADDLE_ENFORCE_NE(
      kernel_iter,
      iter->second.end(),
81
      phi::errors::NotFound(
82 83 84 85 86 87 88 89
          "The kernel with key %s of kernel `%s` is not registered.",
          kernel_key,
          kernel_name));

  return kernel_iter->second;
}

// Convenience overload: builds a KernelKey from (backend, layout, dtype) and
// forwards to SelectKernelOrThrowError(kernel_name, kernel_key). The same
// ALL_LAYOUT fallback and NotFound failures therefore apply. The returned
// reference points into the factory's table, not into the temporary key.
const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name,
    Backend backend,
    DataLayout layout,
    DataType dtype) const {
  return SelectKernelOrThrowError(kernel_name,
                                  KernelKey(backend, layout, dtype));
}

98 99 100 101 102 103 104
// print kernel info with json format:
// {
//   "(CPU, Undefined(AnyLayout), complex64)": {
//   "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//   "output": ["CPU, NCHW, complex64"],
//   "attribute": ["i"]
// }
105
std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
106 107 108
  // input
  os << "{\"input\":[";
  bool need_comma = false;
109
  for (auto& in_def : kernel.args_def().input_defs()) {
110 111 112 113
    if (need_comma) os << ",";
    os << "\"" << in_def.backend << ", " << in_def.layout << ", "
       << in_def.dtype << "\"";
    need_comma = true;
114
  }
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
  os << "],";

  // output
  os << "\"output\":[";
  need_comma = false;
  for (auto& out_def : kernel.args_def().output_defs()) {
    if (need_comma) os << ",";
    os << "\"" << out_def.backend << ", " << out_def.layout << ", "
       << out_def.dtype << "\"";
    need_comma = true;
  }
  os << "],";

  // attr
  os << "\"attribute\":[";
  need_comma = false;
  for (auto& arg_def : kernel.args_def().attribute_defs()) {
    if (need_comma) os << ",";
    os << "\"" << arg_def.type_index.name() << "\"";
    need_comma = true;
  }
  os << "]}";

138 139 140
  return os;
}

141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
// print all kernels info with json format:
// {
//  "kernel_name1":
//      [
//        {
//          "(CPU, Undefined(AnyLayout), complex64)": {
//          "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//          "output": ["CPU, NCHW, complex64"],
//          "attribute": ["i"]
//        },
//        ...
//      ],
//    "kernel_name2": []
//    ...
// }
156
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
157 158
  os << "{";
  bool need_comma_kernels = false;
159
  for (const auto& op_kernel_pair : kernel_factory.kernels()) {
160 161 162
    if (need_comma_kernels) os << ",";
    os << "\"" << op_kernel_pair.first << "\":[";
    bool need_comma_per_kernel = false;
163
    for (const auto& kernel_pair : op_kernel_pair.second) {
164 165 166
      if (need_comma_per_kernel) os << ",";
      os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
      need_comma_per_kernel = true;
167
    }
168 169
    os << "]";
    need_comma_kernels = true;
170
  }
171 172
  os << "}";

173 174 175
  return os;
}

176
}  // namespace phi