//   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace phi {

// Re-export paddle::experimental::DataType into the phi namespace so the
// rest of this header can refer to it unqualified.
using DataType = paddle::experimental::DataType;

// Per-op invocation counters used to build the low-precision kernel report
// (see KernelFactory::AddToLowPrecisionKernelList /
// GetLowPrecisionKernelList below). Counts how many times an op was
// dispatched at each dtype bucket.
struct OpCount {
  // In-class default member initializers replace the old constructor that
  // assigned each field in its body (modernize-use-default-member-init);
  // zero-initialization behavior is unchanged.
  OpCount() = default;

  int fp16_called_{0};   // dispatches with float16 dtype
  int bf16_called_{0};   // dispatches with bfloat16 dtype
  int fp32_called_{0};   // dispatches with float32 dtype
  int other_called_{0};  // dispatches with any other dtype
};

/**
 * [ Naming considerations ]
 *
 * The tensor operation library contains many kernels, and the computation
 * in each specific scenario is represented by a kernel.
 *
 * We directly named it `Kernel` instead of `OpKernel`; the tensor operation
 * library here and fluid are independent, avoiding developers from
 * misunderstanding the relationship between the two concepts.
 */

class KernelContext;

class KernelKey {
 public:
  KernelKey() = default;

  KernelKey(Backend backend, DataLayout layout, DataType dtype)
      : backend_(backend), layout_(layout), dtype_(dtype) {}

70
  explicit KernelKey(const Place& place)
71 72 73 74
      : backend_(TransToPhiBackend(place)),
        layout_(DataLayout::ALL_LAYOUT),
        dtype_(DataType::ALL_DTYPE) {}

75
  explicit KernelKey(const int& dtype, const Place& place)
76 77 78 79
      : backend_(TransToPhiBackend(place)),
        layout_(DataLayout::ALL_LAYOUT),
        dtype_(phi::TransToPhiDataType(dtype)) {}

80 81 82
  explicit KernelKey(const Place& place,
                     const DataLayout& layout,
                     const DataType& dtype)
83 84
      : backend_(TransToPhiBackend(place)), layout_(layout), dtype_(dtype) {}

85 86 87 88
  Backend backend() const { return backend_; }
  DataLayout layout() const { return layout_; }
  DataType dtype() const { return dtype_; }

89 90 91 92
  void set_backend(const Backend& backend) { backend_ = backend; }
  void set_layout(const DataLayout& layout) { layout_ = layout; }
  void set_dtype(const DataType& dtype) { dtype_ = dtype; }

93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
  struct Hash {
    // Note: Now the number of bits we need does not exceed 32 bits, so there is
    // no need to use 64 bits. If needed in the future, it can be expanded,
    // but now we don’t over-design.
    uint32_t operator()(const KernelKey& key) const;
  };

  uint32_t hash_value() const { return Hash()(*this); }

  bool operator<(const KernelKey& key) const {
    return hash_value() < key.hash_value();
  }

  bool operator==(const KernelKey& key) const {
    return hash_value() == key.hash_value();
  }

  bool operator!=(const KernelKey& key) const {
    return hash_value() != key.hash_value();
  }

 private:
  // In total should be smaller than 32.
  constexpr static int kBackendBitLength = 8;
  constexpr static int kDataLayoutBitLength = 4;
  constexpr static int kDataTypeBitLength = 8;

  Backend backend_{Backend::UNDEFINED};
  DataLayout layout_{DataLayout::UNDEFINED};
  DataType dtype_{DataType::UNDEFINED};
};

// TODO(chenweihang): how deal with vector<Param>?
struct TensorArgDef {
  Backend backend;
  DataLayout layout;
  DataType dtype;
H
hong 已提交
130
  std::type_index type_index;
131

H
hong 已提交
132 133 134 135 136 137 138 139
  TensorArgDef(Backend in_backend,
               DataLayout in_layout,
               DataType in_dtype,
               std::type_index in_type_index)
      : backend(in_backend),
        layout(in_layout),
        dtype(in_dtype),
        type_index(in_type_index) {}
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156

  TensorArgDef& SetBackend(Backend in_backend) {
    backend = in_backend;
    return *this;
  }

  TensorArgDef& SetDataLayout(DataLayout in_layout) {
    layout = in_layout;
    return *this;
  }

  TensorArgDef& SetDataType(DataType in_dtype) {
    dtype = in_dtype;
    return *this;
  }
};

// Align the original fluid Attribute type with lower overhead.
// NOTE: enumerator order is load-bearing — do not reorder or insert in the
// middle; registration code relies on these values.
enum class AttributeType {
  UNDEFINED = 0,
  // scalar attributes
  BOOL,
  INT32,
  INT64,
  FLOAT32,
  FLOAT64,
  STRING,
  // plural (list) forms of the scalar attributes above
  BOOLS,
  INT32S,
  INT64S,
  FLOAT32S,
  FLOAT64S,
  STRINGS,
  // phi wrapper attribute types
  SCALAR,
  SCALARS,
  INT_ARRAY,
  // type / placement descriptors
  DATA_TYPE,
  DATA_LAYOUT,
  PLACE
};

// Describes a single non-tensor attribute argument of a kernel.
struct AttributeArgDef {
  AttributeType type_index;

  explicit AttributeArgDef(AttributeType type_index) : type_index(type_index) {}
};

class KernelArgsDef {
 public:
  KernelArgsDef() = default;

H
hong 已提交
190 191 192 193 194
  void AppendInput(Backend backend,
                   DataLayout layout,
                   DataType dtype,
                   std::type_index type_index) {
    input_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
195 196
  }

H
hong 已提交
197 198 199 200 201
  void AppendOutput(Backend backend,
                    DataLayout layout,
                    DataType dtype,
                    std::type_index type_index) {
    output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
202 203
  }

204
  void AppendAttribute(AttributeType type_index) {
205 206 207
    attribute_defs_.emplace_back(AttributeArgDef(type_index));
  }

C
Chen Weihang 已提交
208
  const paddle::small_vector<TensorArgDef, kInputSmallVectorSize>& input_defs()
209
      const {
210 211 212
    return input_defs_;
  }

C
Chen Weihang 已提交
213 214
  const paddle::small_vector<TensorArgDef, kOutputSmallVectorSize>&
  output_defs() const {
215 216 217
    return output_defs_;
  }

C
Chen Weihang 已提交
218
  const paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize>&
219
  attribute_defs() const {
220 221 222
    return attribute_defs_;
  }

C
Chen Weihang 已提交
223
  paddle::small_vector<TensorArgDef, kInputSmallVectorSize>& input_defs() {
224 225
    return input_defs_;
  }
226

C
Chen Weihang 已提交
227
  paddle::small_vector<TensorArgDef, kOutputSmallVectorSize>& output_defs() {
228 229
    return output_defs_;
  }
230

C
Chen Weihang 已提交
231 232
  paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize>&
  attribute_defs() {
233 234 235 236
    return attribute_defs_;
  }

 private:
C
Chen Weihang 已提交
237 238 239
  paddle::small_vector<TensorArgDef, kInputSmallVectorSize> input_defs_{{}};
  paddle::small_vector<TensorArgDef, kOutputSmallVectorSize> output_defs_{{}};
  paddle::small_vector<AttributeArgDef, kAttrSmallVectorSize> attribute_defs_{
240
      {}};
241 242
};

// How a kernel was registered: FUNCTION kernels carry a variadic entry
// point, STRUCTURE kernels do not (see the Kernel constructor below).
enum class KernelRegisteredType { FUNCTION, STRUCTURE };

class Kernel {
 public:
247
  // for map element construct
248 249
  Kernel() = default;

250
  explicit Kernel(KernelFn fn, void* variadic_fn)
251 252 253 254 255 256 257
      : fn_(fn), variadic_fn_(variadic_fn) {
    if (variadic_fn == nullptr) {
      kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
    } else {
      kernel_registered_type_ = KernelRegisteredType::FUNCTION;
    }
  }
258 259 260

  void operator()(KernelContext* ctx) const { fn_(ctx); }

261 262 263 264 265 266
  template <typename Fn>
  Fn GetVariadicKernelFn() const {
    auto* func = reinterpret_cast<Fn>(variadic_fn_);
    return func;
  }

267 268 269 270
  KernelArgsDef* mutable_args_def() { return &args_def_; }

  const KernelArgsDef& args_def() const { return args_def_; }

271 272 273 274
  const TensorArgDef& InputAt(size_t idx) const {
    return args_def_.input_defs().at(idx);
  }

275 276
  TensorArgDef& InputAt(size_t idx) { return args_def_.input_defs().at(idx); }

277 278 279 280
  const TensorArgDef& OutputAt(size_t idx) const {
    return args_def_.output_defs().at(idx);
  }

281 282
  TensorArgDef& OutputAt(size_t idx) { return args_def_.output_defs().at(idx); }

283
  bool IsValid() const { return fn_ != nullptr; }
284

285 286 287 288
  KernelRegisteredType GetKernelRegisteredType() const {
    return kernel_registered_type_;
  }

289 290
  GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr};

291 292
 private:
  KernelFn fn_{nullptr};
293
  void* variadic_fn_ = nullptr;
294
  KernelArgsDef args_def_;
295
  KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
296 297
};

// Maps a KernelKey (backend/layout/dtype triple) to its registered Kernel.
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;

// Maps a kernel name (e.g. "scale") to all of its keyed variants.
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;

// Result of a kernel lookup: the selected kernel and whether selection had
// to fall back to a CPU kernel.
struct KernelResult {
  KernelResult(const Kernel& kernel, bool fallback_cpu)
      : kernel(kernel), has_fallback_cpu(fallback_cpu) {}

  // NOTE(review): `kernel` is a reference, not a copy — a KernelResult must
  // not outlive the registry entry it refers to; verify against callers.
  const Kernel& kernel;
  bool has_fallback_cpu = false;
};

/**
 * Note: Each Computation need a basic kernel map that named by kernel_name.
 *       Such as for scale op, KernelMap contains a `scale` kernel map,
 *       if it still need other overload kernel, the op name can be
 *       `scale.***`.
 */
class KernelFactory {
 public:
  static KernelFactory& Instance();

320
  KernelNameMap& kernels() { return kernels_; }
321

322
  bool HasCompatiblePhiKernel(const std::string& op_type) const;
323

324 325
  bool HasStructuredKernel(const std::string& op_type) const;

326
  KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
327
                                        const KernelKey& kernel_key) const;
328

329 330
  bool HasKernel(const std::string& kernel_name,
                 const KernelKey& kernel_key) const;
331

332 333
  const Kernel& SelectKernel(const std::string& kernel_name,
                             const KernelKey& kernel_key) const;
334

335
  KernelKeyMap SelectKernelMap(const std::string& kernel_name) const;
336

337 338 339
  const KernelArgsDef& GetFirstKernelArgsDef(
      const std::string& kernel_name) const;

340 341 342 343
  void AddToLowPrecisionKernelList(
      const std::string& name,
      const paddle::experimental::DataType& kernel_key_type);

344
  std::map<const std::string, OpCount> GetLowPrecisionKernelList();
345

346 347 348
 private:
  KernelFactory() = default;

349
  KernelNameMap kernels_;
350 351

  // Get the low precision kernel list of current module.
352
  std::map<const std::string, OpCount> low_precision_kernels_;
353 354 355 356 357 358 359 360
};

// Prints a KernelKey as "(backend, layout, dtype)".
inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
  os << "(" << kernel_key.backend();
  os << ", " << kernel_key.layout();
  os << ", " << kernel_key.dtype();
  os << ")";
  return os;
}

std::ostream& operator<<(std::ostream& os, AttributeType attr_type);

363 364 365 366
std::ostream& operator<<(std::ostream& os, const Kernel& kernel);

std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory);

}  // namespace phi