//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace phi {

// Per-op dispatch counters used to build the low-precision kernel report:
// how many times an op was launched with fp16 / bf16 / fp32 / other dtypes.
// (See KernelFactory::AddToLowPrecisionKernelList below.)
struct OpCount {
  // Use in-class member initializers instead of a hand-written zeroing
  // constructor — same behavior, matches the style used elsewhere in this
  // file (e.g. KernelKey's members).
  OpCount() = default;

  int fp16_called_{0};
  int bf16_called_{0};
  int fp32_called_{0};
  int other_called_{0};
};

48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
/**
 * [ Naming considerations ]
 *
 * The tensor operation library contains many kernels, and the computation
 * in each specific scenario is represented by an kernel.
 *
 * We directly named it `Kernel` instead of `Kernel`, the tensor operation
 * library here and fluid are independent, avoiding developers from
 * misunderstanding the relationship between the two concepts.
 */

class KernelContext;

class KernelKey {
 public:
  KernelKey() = default;

  KernelKey(Backend backend, DataLayout layout, DataType dtype)
      : backend_(backend), layout_(layout), dtype_(dtype) {}

68
  explicit KernelKey(const Place& place)
69 70 71 72
      : backend_(TransToPhiBackend(place)),
        layout_(DataLayout::ALL_LAYOUT),
        dtype_(DataType::ALL_DTYPE) {}

73
  explicit KernelKey(const int& dtype, const Place& place)
74 75 76 77
      : backend_(TransToPhiBackend(place)),
        layout_(DataLayout::ALL_LAYOUT),
        dtype_(phi::TransToPhiDataType(dtype)) {}

78 79 80
  explicit KernelKey(const Place& place,
                     const DataLayout& layout,
                     const DataType& dtype)
81 82
      : backend_(TransToPhiBackend(place)), layout_(layout), dtype_(dtype) {}

83 84 85 86
  Backend backend() const { return backend_; }
  DataLayout layout() const { return layout_; }
  DataType dtype() const { return dtype_; }

87 88 89 90
  void set_backend(const Backend& backend) { backend_ = backend; }
  void set_layout(const DataLayout& layout) { layout_ = layout; }
  void set_dtype(const DataType& dtype) { dtype_ = dtype; }

91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
  struct Hash {
    // Note: Now the number of bits we need does not exceed 32 bits, so there is
    // no need to use 64 bits. If needed in the future, it can be expanded,
    // but now we don’t over-design.
    uint32_t operator()(const KernelKey& key) const;
  };

  uint32_t hash_value() const { return Hash()(*this); }

  bool operator<(const KernelKey& key) const {
    return hash_value() < key.hash_value();
  }

  bool operator==(const KernelKey& key) const {
    return hash_value() == key.hash_value();
  }

  bool operator!=(const KernelKey& key) const {
    return hash_value() != key.hash_value();
  }

 private:
  // In total should be smaller than 32.
  constexpr static int kBackendBitLength = 8;
  constexpr static int kDataLayoutBitLength = 4;
  constexpr static int kDataTypeBitLength = 8;

  Backend backend_{Backend::UNDEFINED};
  DataLayout layout_{DataLayout::UNDEFINED};
  DataType dtype_{DataType::UNDEFINED};
};

// TODO(chenweihang): how deal with vector<Param>?
struct TensorArgDef {
  Backend backend;
  DataLayout layout;
  DataType dtype;
H
hong 已提交
128
  std::type_index type_index;
129

H
hong 已提交
130 131 132 133 134 135 136 137
  TensorArgDef(Backend in_backend,
               DataLayout in_layout,
               DataType in_dtype,
               std::type_index in_type_index)
      : backend(in_backend),
        layout(in_layout),
        dtype(in_dtype),
        type_index(in_type_index) {}
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154

  TensorArgDef& SetBackend(Backend in_backend) {
    backend = in_backend;
    return *this;
  }

  TensorArgDef& SetDataLayout(DataLayout in_layout) {
    layout = in_layout;
    return *this;
  }

  TensorArgDef& SetDataType(DataType in_dtype) {
    dtype = in_dtype;
    return *this;
  }
};

// Align the original fluid Attribute type with lower overhead.
// NOTE: enumerator order is part of the on-disk/ABI contract with registration
// code — do not reorder.
enum class AttributeType {
  UNDEFINED = 0,
  // scalar attribute kinds
  BOOL,
  INT32,
  INT64,
  FLOAT32,
  FLOAT64,
  STRING,
  // list attribute kinds
  BOOLS,
  INT32S,
  INT64S,
  FLOAT32S,
  FLOAT64S,
  STRINGS,
  // phi wrapper types
  SCALAR,
  SCALARS,
  INT_ARRAY,
  // framework enums
  DATA_TYPE,
  DATA_LAYOUT,
  PLACE
};

178
struct AttributeArgDef {
179
  AttributeType type_index;
180

181
  explicit AttributeArgDef(AttributeType type_index) : type_index(type_index) {}
182 183 184 185 186 187
};

class KernelArgsDef {
 public:
  KernelArgsDef() = default;

H
hong 已提交
188 189 190 191 192
  void AppendInput(Backend backend,
                   DataLayout layout,
                   DataType dtype,
                   std::type_index type_index) {
    input_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
193 194
  }

H
hong 已提交
195 196 197 198 199
  void AppendOutput(Backend backend,
                    DataLayout layout,
                    DataType dtype,
                    std::type_index type_index) {
    output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
200 201
  }

202
  void AppendAttribute(AttributeType type_index) {
203 204 205
    attribute_defs_.emplace_back(AttributeArgDef(type_index));
  }

C
Chen Weihang 已提交
206
  const paddle::small_vector<TensorArgDef, kInputSmallVectorSize>& input_defs()
207
      const {
208 209 210
    return input_defs_;
  }

C
Chen Weihang 已提交
211 212
  const paddle::small_vector<TensorArgDef, kOutputSmallVectorSize>&
  output_defs() const {
213 214 215
    return output_defs_;
  }

C
Chen Weihang 已提交
216
  const paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize>&
217
  attribute_defs() const {
218 219 220
    return attribute_defs_;
  }

C
Chen Weihang 已提交
221
  paddle::small_vector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
222 223
    return input_defs_;
  }
224

C
Chen Weihang 已提交
225
  paddle::small_vector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
226 227
    return output_defs_;
  }
228

C
Chen Weihang 已提交
229 230
  paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize>&
  attribute_defs() {
231 232 233 234
    return attribute_defs_;
  }

 private:
C
Chen Weihang 已提交
235 236 237
  paddle::small_vector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
  paddle::small_vector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
  paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
238
      {}};
239 240
};

// Whether the kernel was registered as a function kernel (it has a variadic
// entry point) or as a structure kernel (it does not) — see Kernel's ctor.
enum class KernelRegisteredType { FUNCTION, STRUCTURE };

243 244
class Kernel {
 public:
245
  // for map element construct
246 247
  Kernel() = default;

248
  explicit Kernel(KernelFn fn, void* variadic_fn)
249 250 251 252 253 254 255
      : fn_(fn), variadic_fn_(variadic_fn) {
    if (variadic_fn == nullptr) {
      kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
    } else {
      kernel_registered_type_ = KernelRegisteredType::FUNCTION;
    }
  }
256 257 258

  void operator()(KernelContext* ctx) const { fn_(ctx); }

259 260 261 262 263 264
  template <typename Fn>
  Fn GetVariadicKernelFn() const {
    auto* func = reinterpret_cast<Fn>(variadic_fn_);
    return func;
  }

265 266 267 268
  KernelArgsDef* mutable_args_def() { return &args_def_; }

  const KernelArgsDef& args_def() const { return args_def_; }

269 270 271 272
  const TensorArgDef& InputAt(size_t idx) const {
    return args_def_.input_defs().at(idx);
  }

273 274
  TensorArgDef& InputAt(size_t idx) { return args_def_.input_defs().at(idx); }

275 276 277 278
  const TensorArgDef& OutputAt(size_t idx) const {
    return args_def_.output_defs().at(idx);
  }

279 280
  TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }

281
  bool IsValid() const { return fn_ != nullptr; }
282

283 284 285 286
  KernelRegisteredType GetKernelRegisteredType() const {
    return kernel_registered_type_;
  }

287 288
  GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr};

289 290
 private:
  KernelFn fn_{nullptr};
291
  void* variadic_fn_ = nullptr;
292
  KernelArgsDef args_def_;
293
  KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
294 295
};

// All kernels registered under one kernel name, indexed by KernelKey
// (backend, layout, dtype).
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;

// Top-level registry map: kernel name -> KernelKeyMap.
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;

300 301 302 303 304 305 306 307
struct KernelResult {
  KernelResult(const Kernel& kernel, bool fallback_cpu)
      : kernel(kernel), has_fallback_cpu(fallback_cpu) {}

  const Kernel& kernel;
  bool has_fallback_cpu = false;
};

308 309 310 311 312 313 314 315 316 317
/**
 * Note: Each Computation need a basic kernel map that named by kernel_name.
 *       Such as for scale op, KernelMap contains a `scale` kernel map,
 *       if it still need other overload kernel, the op name can be
 *       `scale.***`.
 */
class KernelFactory {
 public:
  static KernelFactory& Instance();

318
  KernelNameMap& kernels() { return kernels_; }
319

320
  bool HasCompatiblePhiKernel(const std::string& op_type) const;
321

322 323
  bool HasStructuredKernel(const std::string& op_type) const;

324
  KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
325
                                        const KernelKey& kernel_key) const;
326

327 328
  bool HasKernel(const std::string& kernel_name,
                 const KernelKey& kernel_key) const;
329

330 331
  const Kernel& SelectKernel(const std::string& kernel_name,
                             const KernelKey& kernel_key) const;
332

333
  KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
334

335 336 337
  const KernelArgsDef& GetFirstKernelArgsDef(
      const std::string& kernel_name) const;

338 339
  void AddToLowPrecisionKernelList(const std::string& name,
                                   const DataType& kernel_key_type);
340

341
  std::map<const std::string, OpCount> GetLowPrecisionKernelList();
342

343 344 345
 private:
  KernelFactory() = default;

346
  KernelNameMap kernels_;
347 348

  // Get the low precision kernel list of current module.
349
  std::map<const std::string, OpCount> low_precision_kernels_;
350 351 352 353 354 355 356 357
};

// Prints a KernelKey as "(backend, layout, dtype)".
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
  os << "(" << kernel_key.backend();
  os << ", " << kernel_key.layout();
  os << ", " << kernel_key.dtype() << ")";
  return os;
}

358 359
std::ostream& operator<<(std::ostream& os, AttributeType attr_type);

360 361 362 363
std::ostream& operator<<(std::ostream& os, const Kernel& kernel);

std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory);

}  // namespace phi