// kernel_factory.cc
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/kernel_factory.h"
16 17

// See Note [ Why still include the fluid headers? ]
18
#include "paddle/phi/core/enforce.h"
19

20
namespace phi {
21 22 23 24 25 26 27 28 29 30

uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
  // Packs the three key components into disjoint bit ranges of one 32-bit
  // value (low bits first):
  // |----31-20------|---19-12---|---11-8----|---7-0---|
  // | For extension | DataType | DataLayout | Backend |
  // NOTE(review): layout is narrowed to 8 bits but shifted into a field that
  // the diagram above gives only 4 bits; this relies on DataLayout values
  // fitting in kDataLayoutBitLength bits — confirm against the enum.
  const int layout_shift = KernelKey::kBackendBitLength;
  const int dtype_shift =
      KernelKey::kBackendBitLength + KernelKey::kDataLayoutBitLength;

  const uint32_t backend_bits = static_cast<uint8_t>(key.backend());
  const uint32_t layout_bits = static_cast<uint8_t>(key.layout());
  const uint32_t dtype_bits = static_cast<uint16_t>(key.dtype());

  return backend_bits | (layout_bits << layout_shift) |
         (dtype_bits << dtype_shift);
}

KernelFactory& KernelFactory::Instance() {
  // Process-wide registry. The function-local static is constructed on first
  // call, and C++11 guarantees that initialization is thread-safe.
  static KernelFactory instance;
  return instance;
}

Y
YuanRisheng 已提交
40
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
                                   const KernelKey& kernel_key) const {
  // Non-throwing lookup: returns a copy of the kernel registered under
  // exactly (kernel_name, kernel_key), or a default-constructed Kernel when
  // either the name or the key is unknown.
  // Unlike SelectKernelOrThrowError, no ALL_LAYOUT fallback is attempted.
  const auto name_iter = kernels_.find(kernel_name);
  if (name_iter != kernels_.end()) {
    const auto& registered = name_iter->second;
    const auto key_iter = registered.find(kernel_key);
    if (key_iter != registered.end()) {
      return key_iter->second;
    }
  }
  return Kernel();
}

53 54
KernelKeyMap KernelFactory::SelectKernelMap(
    const std::string& kernel_name) const {
  // Returns a copy of every (KernelKey -> Kernel) registration for
  // kernel_name; an unregistered name yields an empty map instead of an error.
  const auto name_iter = kernels_.find(kernel_name);
  if (name_iter != kernels_.end()) {
    return name_iter->second;
  }
  return KernelKeyMap();
}

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
                                        const KernelKey& kernel_key) const {
  // Raises NotFound when kernel_name itself was never registered; otherwise
  // reports whether a kernel exists for exactly kernel_key.
  auto name_iter = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      name_iter,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  // Exact-key check only: unlike SelectKernelOrThrowError, an ALL_LAYOUT
  // registration is not consulted here.
  const auto& registered = name_iter->second;
  return registered.find(kernel_key) != registered.end();
}

77
const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name, const KernelKey& kernel_key) const {
  // Looks up the kernel for (kernel_name, kernel_key) and returns a reference
  // into the registry. Raises NotFound if the name is unknown, or if neither
  // the exact key nor its ALL_LAYOUT variant is registered.
  auto name_iter = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      name_iter,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  const auto& registered = name_iter->second;
  auto found = registered.find(kernel_key);

  // TODO(chenweihang): polish refind impl here
  // A kernel registered for ALL_LAYOUT serves any concrete layout, so retry
  // with the layout widened when the exact key is missing.
  const bool try_any_layout =
      found == registered.end() &&
      kernel_key.layout() != phi::DataLayout::ALL_LAYOUT;
  if (try_any_layout) {
    const phi::KernelKey widened_key(
        kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
    found = registered.find(widened_key);
  }

  PADDLE_ENFORCE_NE(
      found,
      registered.end(),
      phi::errors::NotFound(
          "The kernel with key %s of kernel `%s` is not registered.",
          kernel_key,
          kernel_name));

  return found->second;
}

const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name,
    Backend backend,
    DataLayout layout,
    DataType dtype) const {
  // Convenience overload: assembles the KernelKey from its components and
  // defers to the KernelKey-based lookup (same fallback and error behavior).
  const KernelKey requested_key(backend, layout, dtype);
  return SelectKernelOrThrowError(kernel_name, requested_key);
}

113 114 115 116 117 118 119
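// Print a single kernel's argument info in JSON format, e.g.:
// {
//   "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//   "output": ["CPU, NCHW, complex64"],
//   "attribute": ["i"]
// }
// The "(Backend, Layout, DataType)" key that identifies the kernel is not
// printed here; the KernelFactory printer wraps this object with it.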
120
namespace {

// Streams `"<field>":[...]`, rendering each tensor argument definition in
// `defs` as a quoted "backend, layout, dtype" string, comma-separated.
template <typename ArgDefList>
void PrintTensorArgDefs(std::ostream& os,
                        const char* field,
                        const ArgDefList& defs) {
  os << "\"" << field << "\":[";
  const char* sep = "";
  for (const auto& def : defs) {
    os << sep << "\"" << def.backend << ", " << def.layout << ", "
       << def.dtype << "\"";
    sep = ",";
  }
  os << "]";
}

}  // namespace

std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
  // Emits {"input":[...],"output":[...],"attribute":[...]} with no
  // whitespace between tokens.
  const auto& args_def = kernel.args_def();

  os << "{";
  PrintTensorArgDefs(os, "input", args_def.input_defs());
  os << ",";
  PrintTensorArgDefs(os, "output", args_def.output_defs());
  os << ",";

  // Attributes are identified only by their std::type_index name, whose
  // spelling is implementation-defined (e.g. "i" for int on GCC/Clang).
  os << "\"attribute\":[";
  const char* sep = "";
  for (const auto& attr_def : args_def.attribute_defs()) {
    os << sep << "\"" << attr_def.type_index.name() << "\"";
    sep = ",";
  }
  os << "]}";

  return os;
}

156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
// Print info for all registered kernels in JSON format, e.g.:
// {
//   "kernel_name1": [
//     {
//       "(CPU, Undefined(AnyLayout), complex64)": {
//         "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//         "output": ["CPU, NCHW, complex64"],
//         "attribute": ["i"]
//       }
//     },
//     ...
//   ],
//   "kernel_name2": [],
//   ...
// }
171
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
172 173
  os << "{";
  bool need_comma_kernels = false;
174
  for (const auto& op_kernel_pair : kernel_factory.kernels()) {
175 176 177
    if (need_comma_kernels) os << ",";
    os << "\"" << op_kernel_pair.first << "\":[";
    bool need_comma_per_kernel = false;
178
    for (const auto& kernel_pair : op_kernel_pair.second) {
179 180 181
      if (need_comma_per_kernel) os << ",";
      os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
      need_comma_per_kernel = true;
182
    }
183 184
    os << "]";
    need_comma_kernels = true;
185
  }
186 187
  os << "}";

188 189 190
  return os;
}

191
}  // namespace phi