kernel_factory.cc 6.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/kernel_factory.h"
16 17

// See Note [ Why still include the fluid headers? ]
18
#include "paddle/phi/core/enforce.h"
19

20
namespace phi {
21 22 23 24 25 26 27 28 29 30

uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
  uint32_t hash_value = 0;
  // |----31-20------|---19-12---|---11-8----|---7-0---|
  // | For extension | DataType | DataLayout | Backend |
  hash_value |= static_cast<uint8_t>(key.backend());
  hash_value |=
      (static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
  hash_value |=
      (static_cast<uint16_t>(key.dtype())
31
       << (KernelKey::kBackendBitLength + KernelKey::kDataLayoutBitLength));
32 33 34 35 36 37 38 39
  return hash_value;
}

// Returns the single process-wide KernelFactory. The function-local static is
// constructed on first call; C++11 guarantees that initialization is
// thread-safe.
KernelFactory& KernelFactory::Instance() {
  static KernelFactory factory_instance;
  return factory_instance;
}

Y
YuanRisheng 已提交
40
Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
41 42 43 44 45 46 47 48 49 50 51 52
                                   const KernelKey& kernel_key) const {
  auto iter = kernels_.find(kernel_name);
  if (iter == kernels_.end()) {
    return Kernel();
  }
  auto kernel_iter = iter->second.find(kernel_key);
  if (kernel_iter == iter->second.end()) {
    return Kernel();
  }
  return kernel_iter->second;
}

53 54
KernelKeyMap KernelFactory::SelectKernelMap(
    const std::string& kernel_name) const {
55 56
  auto iter = kernels_.find(kernel_name);
  if (iter == kernels_.end()) {
57
    return KernelKeyMap();
58 59 60 61
  }
  return iter->second;
}

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
                                        const KernelKey& kernel_key) const {
  auto iter = kernels_.find(kernel_name);
  PADDLE_ENFORCE_NE(
      iter,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));

  auto kernel_iter = iter->second.find(kernel_key);
  if (kernel_iter == iter->second.end()) {
    return false;
  }
  return true;
}

77
const Kernel& KernelFactory::SelectKernelOrThrowError(
Z
zyfncg 已提交
78 79 80
    const std::string& kernel_name,
    const KernelKey& kernel_key,
    bool use_cudnn) const {
81
  auto iter = kernels_.find(kernel_name);
82 83 84 85
  PADDLE_ENFORCE_NE(
      iter,
      kernels_.end(),
      phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
86

Z
zyfncg 已提交
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (use_cudnn && kernel_key.backend() == Backend::GPU) {
    auto kernel_iter = iter->second.find(
        {Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
    if (kernel_iter == iter->second.end() &&
        kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
      kernel_iter = iter->second.find(
          {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
    }
    if (kernel_iter != iter->second.end()) {
      return kernel_iter->second;
    }
    LOG(WARNING) << "The cudnn kernel for [" << kernel_name
                 << "] is not registered.";
  }
#endif
103 104
  auto kernel_iter = iter->second.find(kernel_key);
  // TODO(chenweihang): polish refind impl here
105
  if (kernel_iter == iter->second.end() &&
106 107 108
      kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
    phi::KernelKey any_layout_kernel_key(
        kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
109 110 111 112 113
    kernel_iter = iter->second.find(any_layout_kernel_key);
  }
  PADDLE_ENFORCE_NE(
      kernel_iter,
      iter->second.end(),
114
      phi::errors::NotFound(
115 116 117 118 119 120 121 122
          "The kernel with key %s of kernel `%s` is not registered.",
          kernel_key,
          kernel_name));

  return kernel_iter->second;
}

// Convenience overload: builds a KernelKey from (backend, layout, dtype) and
// forwards to the KernelKey-based overload. `use_cudnn` is left to whatever
// default that overload declares (NOTE(review): default lives in the header —
// presumably false; confirm).
const Kernel& KernelFactory::SelectKernelOrThrowError(
    const std::string& kernel_name,
    Backend backend,
    DataLayout layout,
    DataType dtype) const {
  const KernelKey kernel_key(backend, layout, dtype);
  return SelectKernelOrThrowError(kernel_name, kernel_key);
}

131 132 133 134 135 136 137
// print kernel info with json format:
// {
//   "(CPU, Undefined(AnyLayout), complex64)": {
//   "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//   "output": ["CPU, NCHW, complex64"],
//   "attribute": ["i"]
// }
138
std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
139 140 141
  // input
  os << "{\"input\":[";
  bool need_comma = false;
142
  for (auto& in_def : kernel.args_def().input_defs()) {
143 144 145 146
    if (need_comma) os << ",";
    os << "\"" << in_def.backend << ", " << in_def.layout << ", "
       << in_def.dtype << "\"";
    need_comma = true;
147
  }
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
  os << "],";

  // output
  os << "\"output\":[";
  need_comma = false;
  for (auto& out_def : kernel.args_def().output_defs()) {
    if (need_comma) os << ",";
    os << "\"" << out_def.backend << ", " << out_def.layout << ", "
       << out_def.dtype << "\"";
    need_comma = true;
  }
  os << "],";

  // attr
  os << "\"attribute\":[";
  need_comma = false;
  for (auto& arg_def : kernel.args_def().attribute_defs()) {
    if (need_comma) os << ",";
    os << "\"" << arg_def.type_index.name() << "\"";
    need_comma = true;
  }
  os << "]}";

171 172 173
  return os;
}

174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
// print all kernels info with json format:
// {
//  "kernel_name1":
//      [
//        {
//          "(CPU, Undefined(AnyLayout), complex64)": {
//          "input": ["CPU, NCHW, complex64", "CPU, NCHW, complex64"],
//          "output": ["CPU, NCHW, complex64"],
//          "attribute": ["i"]
//        },
//        ...
//      ],
//    "kernel_name2": []
//    ...
// }
189
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) {
190 191
  os << "{";
  bool need_comma_kernels = false;
192
  for (const auto& op_kernel_pair : kernel_factory.kernels()) {
193 194 195
    if (need_comma_kernels) os << ",";
    os << "\"" << op_kernel_pair.first << "\":[";
    bool need_comma_per_kernel = false;
196
    for (const auto& kernel_pair : op_kernel_pair.second) {
197 198 199
      if (need_comma_per_kernel) os << ",";
      os << "{\"" << kernel_pair.first << "\":" << kernel_pair.second << "}";
      need_comma_per_kernel = true;
200
    }
201 202
    os << "]";
    need_comma_kernels = true;
203
  }
204 205
  os << "}";

206 207 208
  return os;
}

209
}  // namespace phi