storage_manager.cc 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/core/storage_manager.h"

#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>

#include "paddle/ir/core/enforce.h"

namespace ir {
// This is a structure for creating, caching, and looking up Storage of
C
co63oc 已提交
24
// parametric types.
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
struct ParametricStorageManager {
  using StorageBase = StorageManager::StorageBase;

  ParametricStorageManager() {}

  ~ParametricStorageManager() {
    for (const auto &instance : parametric_instances_) {
      delete instance.second;
    }
    parametric_instances_.clear();
  }

  // Get the storage of parametric type, if not in the cache, create and
  // insert the cache.
  StorageBase *GetOrCreate(std::size_t hash_value,
                           std::function<bool(const StorageBase *)> equal_func,
                           std::function<StorageBase *()> constructor) {
    if (parametric_instances_.count(hash_value) != 0) {
      auto pr = parametric_instances_.equal_range(hash_value);
      while (pr.first != pr.second) {
        if (equal_func(pr.first->second)) {
C
co63oc 已提交
46
          VLOG(4) << "Found a cached parametric storage of: [param_hash="
47 48 49 50 51 52 53 54
                  << hash_value << ", storage_ptr=" << pr.first->second << "].";
          return pr.first->second;
        }
        ++pr.first;
      }
    }
    StorageBase *storage = constructor();
    parametric_instances_.emplace(hash_value, storage);
C
co63oc 已提交
55
    VLOG(4) << "No cache found, construct and cache a new parametric storage "
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
               "of: [param_hash="
            << hash_value << ", storage_ptr=" << storage << "].";
    return storage;
  }

 private:
  // In order to prevent hash conflicts, the unordered_multimap data structure
  // is used for storage.
  std::unordered_multimap<size_t, StorageBase *> parametric_instances_;
};

// Nothing to initialize beyond the members' own default construction.
StorageManager::StorageManager() = default;

// Defined out of line in this file. Presumably this is so that
// ParametricStorageManager, which is only defined here, is a complete type
// wherever the smart pointers held in parametric_instance_ are destroyed.
StorageManager::~StorageManager() = default;

Z
zhangbo9674 已提交
71
StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
72 73 74 75 76
    TypeId type_id,
    std::size_t hash_value,
    std::function<bool(const StorageBase *)> equal_func,
    std::function<StorageBase *()> constructor) {
  std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
C
co63oc 已提交
77
  VLOG(4) << "Try to get a parametric storage of: [TypeId_hash="
78 79
          << std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
          << "].";
H
hong 已提交
80
  if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
81
    IR_THROW("The input data pointer is null.");
H
hong 已提交
82
  }
83 84 85 86
  ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
  return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
}

Z
zhangbo9674 已提交
87
StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
88
    TypeId type_id) {
89 90
  std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
  VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
91
          << std::hash<ir::TypeId>()(type_id) << "].";
92
  if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
93
    IR_THROW("TypeId not found in IrContext.");
94
  StorageBase *parameterless_instance = parameterless_instance_[type_id];
95 96 97
  return parameterless_instance;
}

Z
zhangbo9674 已提交
98
void StorageManager::RegisterParametricStorageImpl(TypeId type_id) {
99
  std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
C
co63oc 已提交
100
  VLOG(4) << "Register a parametric storage of: [TypeId_hash="
101 102 103 104 105
          << std::hash<ir::TypeId>()(type_id) << "].";
  parametric_instance_.emplace(type_id,
                               std::make_unique<ParametricStorageManager>());
}

Z
zhangbo9674 已提交
106
void StorageManager::RegisterParameterlessStorageImpl(
107
    TypeId type_id, std::function<StorageBase *()> constructor) {
108 109
  std::lock_guard<ir::SpinLock> guard(parameterless_instance_lock_);
  VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
110
          << std::hash<ir::TypeId>()(type_id) << "].";
111
  if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
112
    IR_THROW("storage class already registered");
113
  parameterless_instance_.emplace(type_id, constructor());
114 115 116
}

}  // namespace ir