kernel_factory.cc 6.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/kernel_factory.h"
16 17

// See Note [ Why still include the fluid headers? ]
18
#include "paddle/phi/core/enforce.h"
19

20
namespace phi {
21

22 23
// File-local default-constructed Kernel. KernelFactory::SelectKernel returns a
// reference to it when the requested name or key has no registered kernel.
static const Kernel empty_kernel;  // NOLINT

24 25 26 27 28 29 30 31 32
// Packs the backend, layout and dtype of `key` into a single 32-bit value.
//
// Intended bit layout of the result:
//   |----31-20------|---19-12---|---11-8----|---7-0---|
//   | For extension | DataType  | DataLayout | Backend |
//
// The backend and layout are narrowed to 8 bits and the dtype to 16 bits
// before being shifted into place and OR-ed together.
uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
  const uint32_t backend_bits = static_cast<uint8_t>(key.backend());
  const uint32_t layout_bits = static_cast<uint8_t>(key.layout());
  const uint32_t dtype_bits = static_cast<uint16_t>(key.dtype());

  return backend_bits | (layout_bits << KernelKey::kBackendBitLength) |
         (dtype_bits << (KernelKey::kBackendBitLength +
                         KernelKey::kDataLayoutBitLength));
}

// Returns a reference to the function-local static KernelFactory, which is
// constructed the first time this function is called.
KernelFactory& KernelFactory::Instance() {
  static KernelFactory factory_singleton;
  return factory_singleton;
}

42 43
// Looks up the kernel registered under `kernel_name` with exactly
// `kernel_key`.
//
// Returns the matching kernel, or the file-level `empty_kernel` when either
// the name or the key has no entry (no error is raised for a miss).
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
                                          const KernelKey& kernel_key) const {
  const auto name_it = kernels_.find(kernel_name);
  if (name_it != kernels_.end()) {
    const auto& key_map = name_it->second;
    const auto key_it = key_map.find(kernel_key);
    if (key_it != key_map.end()) {
      return key_it->second;
    }
  }
  // Either lookup missed: hand back the shared sentinel.
  return empty_kernel;
}

55 56
// Returns a copy of the key -> kernel map registered under `kernel_name`, or
// a default-constructed KernelKeyMap when the name has no entry.
KernelKeyMap KernelFactory::SelectKernelMap(
    const std::string& kernel_name) const {
  const auto name_it = kernels_.find(kernel_name);
  if (name_it != kernels_.end()) {
    return name_it->second;
  }
  return KernelKeyMap();
}

64 65
// Reports whether a kernel with exactly `kernel_key` is registered under
// `kernel_name`.
//
// Note that an unknown `kernel_name` does not yield `false`: it fails the
// PADDLE_ENFORCE_NE check below with a NotFound error.
bool KernelFactory::HasKernel(const std::string& kernel_name,
                              const KernelKey& kernel_key) const {
  auto name_it = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      name_it,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  const auto& key_map = name_it->second;
  return key_map.find(kernel_key) != key_map.end();
}

79
// Selects the kernel registered under `kernel_name` for `kernel_key`, failing
// a PADDLE_ENFORCE_NE check with a NotFound error when nothing matches.
//
// Lookup order:
//   1. (CUDA/HIP builds only) when `use_cudnn` is set and the key's backend is
//      Backend::GPU, a Backend::GPUDNN kernel with the same layout and dtype,
//      then a Backend::GPUDNN kernel with DataLayout::ALL_LAYOUT. A warning is
//      logged if neither is registered and the search continues below.
//   2. The exact `kernel_key`.
//   3. The same backend and dtype with DataLayout::ALL_LAYOUT.
const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name,
    const KernelKey& kernel_key,
    bool use_cudnn) const {
  auto name_it = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      name_it,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  const auto& key_map = name_it->second;

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (use_cudnn && kernel_key.backend() == Backend::GPU) {
    auto dnn_it = key_map.find(
        KernelKey(Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()));
    if (dnn_it == key_map.end()) {
      // Retry with the layout wildcard unless the request already used it.
      if (kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
        dnn_it = key_map.find(KernelKey(
            Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()));
      }
    }
    if (dnn_it == key_map.end()) {
      LOG(WARNING) << "The cudnn kernel for [" << kernel_name
                   << "] is not registered.";
    } else {
      return dnn_it->second;
    }
  }
#endif

  auto key_it = key_map.find(kernel_key);
  // TODO(chenweihang): polish refind impl here
  if (key_it == key_map.end() &&
      kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
    key_it = key_map.find(phi::KernelKey(
        kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()));
  }

  PADDLE_ENFORCE_NE(
      key_it,
      key_map.end(),
      phi::errors::NotFound(
          "The kernel with key %s of kernel `%s` is not registered.",
          kernel_key,
          kernel_name));

  return key_it->second;
}

// Convenience overload: builds a KernelKey from `backend`, `layout` and
// `dtype` and forwards to the KernelKey-based overload.
// NOTE(review): the forwarded call passes no `use_cudnn` argument, so it
// presumably relies on a default declared in the header -- confirm there.
const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name,
    Backend backend,
    DataLayout layout,
    DataType dtype) const {
  const KernelKey kernel_key(backend, layout, dtype);
  return SelectKernelOrThrowError(kernel_name, kernel_key);
}

133 134 135 136 137 138 139 140 141 142
// Returns the argument definition of the first kernel (in the map's iteration
// order) registered under `kernel_name`. An unknown name fails the
// PADDLE_ENFORCE_NE check with a NotFound error.
const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
    const std::string& kernel_name) const {
  auto name_it = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      name_it,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  // NOTE(review): the first entry is dereferenced without an emptiness check;
  // this assumes every registered name owns at least one kernel -- confirm.
  const auto& key_map = name_it->second;
  const auto first_entry = key_map.cbegin();
  return first_entry->second.args_def();
}

143 144 145 146 147 148 149
// print kernel info with json format:
// {
//   "(CPU, Undefined(AnyLayout), complex64)": {
//   "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//   "output": ["CPU, NCHW, complex64"],
//   "attribute": ["i"]
// }
150
std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
151 152 153
  // input
  os << "{\"input\":[";
  bool need_comma = false;
154
  for (auto& in_def : kernel.args_def().input_defs()) {
155 156 157 158
    if (need_comma) os << ",";
    os << "\"" << in_def.backend << ", " << in_def.layout << ", "
       << in_def.dtype << "\"";
    need_comma = true;
159
  }
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
  os << "],";

  // output
  os << "\"output\":[";
  need_comma = false;
  for (auto& out_def : kernel.args_def().output_defs()) {
    if (need_comma) os << ",";
    os << "\"" << out_def.backend << ", " << out_def.layout << ", "
       << out_def.dtype << "\"";
    need_comma = true;
  }
  os << "],";

  // attr
  os << "\"attribute\":[";
  need_comma = false;
  for (auto& arg_def : kernel.args_def().attribute_defs()) {
    if (need_comma) os << ",";
    os << "\"" << arg_def.type_index.name() << "\"";
    need_comma = true;
  }
  os << "]}";

183 184 185
  return os;
}

186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
// print all kernels info with json format:
// {
//  "kernel_name1":
//      [
//        {
//          "(CPU, Undefined(AnyLayout), complex64)": {
//          "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//          "output": ["CPU, NCHW, complex64"],
//          "attribute": ["i"]
//        },
//        ...
//      ],
//    "kernel_name2": []
//    ...
// }
201
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
202 203
  os << "{";
  bool need_comma_kernels = false;
204
  for (const auto& op_kernel_pair : kernel_factory.kernels()) {
205 206 207
    if (need_comma_kernels) os << ",";
    os << "\"" << op_kernel_pair.first << "\":[";
    bool need_comma_per_kernel = false;
208
    for (const auto& kernel_pair : op_kernel_pair.second) {
209 210 211
      if (need_comma_per_kernel) os << ",";
      os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
      need_comma_per_kernel = true;
212
    }
213 214
    os << "]";
    need_comma_kernels = true;
215
  }
216 217
  os << "}";

218 219 220
  return os;
}

221
}  // namespace phi