// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstddef>
#include <functional>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include "paddle/ir/core/spin_lock.h"
#include "paddle/ir/core/type_id.h"

namespace ir {
///
/// \brief Manages the storage instances of a single parametric storage type.
/// Only forward-declared in this header (it is defined out of line);
/// StorageManager keeps one per TypeId, held through a std::unique_ptr in its
/// parametric_instance_ map.
///
struct ParametricStorageManager;

///
/// \brief A utility class for getting or creating Storage class instances.
/// Storage class must be a derived class of StorageManager::StorageBase.
/// There are two types of Storage class:
/// One is a parameterless type, which can directly obtain an instance through
C
co63oc 已提交
36
/// the get method; The other is a parametric type, which needs to comply with
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
/// the following conditions: (1) Need to define a type alias called ParamKey,
/// it serves as the unique identifier for the Storage class; (2) Need to
/// provide a hash method on the ParamKey for storage and access; (3) Need to
/// provide method 'bool operator==(const ParamKey &) const', used to compare
/// Storage instance and ParamKey instance.
///
class StorageManager {
 public:
  ///
  /// \brief This class is the base class of all storage classes,
  /// and any type of storage needs to inherit from this class.
  ///
  class StorageBase {
   protected:
    StorageBase() = default;
  };

  StorageManager();

  ~StorageManager();

  ///
  /// \brief Get a unique storage instance of parametric Type.
  ///
  /// \param init_func Used to initialize a newly inserted storage instance.
  /// \param type_id The type id of the AbstractType.
  /// \param args Parameters of the wrapped function.
  /// \return A uniqued instance of Storage.
  ///
  template <typename Storage, typename... Args>
Z
zhangbo9674 已提交
67 68 69
  Storage *GetParametricStorage(std::function<void(Storage *)> init_func,
                                TypeId type_id,
                                Args &&...args) {
70 71 72 73 74 75 76 77 78 79 80
    typename Storage::ParamKey param =
        typename Storage::ParamKey(std::forward<Args>(args)...);
    std::size_t hash_value = Storage::HashValue(param);
    auto equal_func = [&param](const StorageBase *existing) {
      return static_cast<const Storage &>(*existing) == param;
    };
    auto constructor = [&]() {
      auto *storage = Storage::Construct(param);
      if (init_func) init_func(storage);
      return storage;
    };
Z
zhangbo9674 已提交
81 82
    return static_cast<Storage *>(
        GetParametricStorageImpl(type_id, hash_value, equal_func, constructor));
83 84 85 86 87 88 89 90 91
  }

  ///
  /// \brief Get a unique storage instance of parameterless Type.
  ///
  /// \param type_id The type id of the AbstractType.
  /// \return A uniqued instance of Storage.
  ///
  template <typename Storage>
Z
zhangbo9674 已提交
92 93
  Storage *GetParameterlessStorage(TypeId type_id) {
    return static_cast<Storage *>(GetParameterlessStorageImpl(type_id));
94 95 96 97 98 99 100 101
  }

  ///
  /// \brief Register a new parametric storage class.
  ///
  /// \param type_id The type id of the AbstractType.
  ///
  template <typename Storage>
Z
zhangbo9674 已提交
102 103
  void RegisterParametricStorage(TypeId type_id) {
    return RegisterParametricStorageImpl(type_id);
104 105 106 107 108 109 110 111 112
  }

  ///
  /// \brief Register a new parameterless storage class.
  ///
  /// \param type_id The type id of the AbstractType.
  /// \param init_func Used to initialize a newly inserted storage instance.
  ///
  template <typename Storage>
Z
zhangbo9674 已提交
113 114
  void RegisterParameterlessStorage(TypeId type_id,
                                    std::function<void(Storage *)> init_func) {
115 116 117 118 119
    auto constructor = [&]() {
      auto *storage = new Storage();
      if (init_func) init_func(storage);
      return storage;
    };
Z
zhangbo9674 已提交
120
    RegisterParameterlessStorageImpl(type_id, constructor);
121 122 123
  }

 private:
Z
zhangbo9674 已提交
124
  StorageBase *GetParametricStorageImpl(
125 126 127 128 129
      TypeId type_id,
      std::size_t hash_value,
      std::function<bool(const StorageBase *)> equal_func,
      std::function<StorageBase *()> constructor);

Z
zhangbo9674 已提交
130
  StorageBase *GetParameterlessStorageImpl(TypeId type_id);
131

Z
zhangbo9674 已提交
132
  void RegisterParametricStorageImpl(TypeId type_id);
133

Z
zhangbo9674 已提交
134
  void RegisterParameterlessStorageImpl(
135 136
      TypeId type_id, std::function<StorageBase *()> constructor);

C
co63oc 已提交
137
  // This map is a mapping between type id and parametric type storage.
138 139 140 141 142 143
  std::unordered_map<TypeId, std::unique_ptr<ParametricStorageManager>>
      parametric_instance_;

  ir::SpinLock parametric_instance_lock_;

  // This map is a mapping between type id and parameterless type storage.
144
  std::unordered_map<TypeId, StorageBase *> parameterless_instance_;
145

146
  ir::SpinLock parameterless_instance_lock_;
147 148 149
};

}  // namespace ir