未验证 提交 3020ad03 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix ir StorageManager bug (#54503)

* fix ir storage bug

* refine code

* refine code

* fix bug

* refine code

* refine code
上级 8dfcf03e
...@@ -25,8 +25,10 @@ ...@@ -25,8 +25,10 @@
// there is no equivalent intrinsics in msvc. // there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition) #define UNLIKELY(condition) (condition)
#endif #endif
template <typename T>
inline bool is_error(bool stat) { return !stat; } inline bool is_error(const T& stat) {
return !stat;
}
namespace ir { namespace ir {
class IrNotMetException : public std::exception { class IrNotMetException : public std::exception {
...@@ -55,7 +57,7 @@ class IrNotMetException : public std::exception { ...@@ -55,7 +57,7 @@ class IrNotMetException : public std::exception {
#define IR_ENFORCE(COND, ...) \ #define IR_ENFORCE(COND, ...) \
do { \ do { \
auto __cond__ = (COND); \ bool __cond__(COND); \
if (UNLIKELY(is_error(__cond__))) { \ if (UNLIKELY(is_error(__cond__))) { \
try { \ try { \
throw ir::IrNotMetException( \ throw ir::IrNotMetException( \
......
...@@ -25,11 +25,12 @@ namespace ir { ...@@ -25,11 +25,12 @@ namespace ir {
struct ParametricStorageManager { struct ParametricStorageManager {
using StorageBase = StorageManager::StorageBase; using StorageBase = StorageManager::StorageBase;
ParametricStorageManager() {} explicit ParametricStorageManager(std::function<void(StorageBase *)> destroy)
: destroy_(destroy) {}
~ParametricStorageManager() { ~ParametricStorageManager() {
for (const auto &instance : parametric_instances_) { for (const auto &instance : parametric_instances_) {
delete instance.second; destroy_(instance.second);
} }
parametric_instances_.clear(); parametric_instances_.clear();
} }
...@@ -37,7 +38,7 @@ struct ParametricStorageManager { ...@@ -37,7 +38,7 @@ struct ParametricStorageManager {
// Get the storage of parametric type, if not in the cache, create and // Get the storage of parametric type, if not in the cache, create and
// insert the cache. // insert the cache.
StorageBase *GetOrCreate(std::size_t hash_value, StorageBase *GetOrCreate(std::size_t hash_value,
std::function<bool(const StorageBase *)> equal_func, std::function<bool(StorageBase *)> equal_func,
std::function<StorageBase *()> constructor) { std::function<StorageBase *()> constructor) {
if (parametric_instances_.count(hash_value) != 0) { if (parametric_instances_.count(hash_value) != 0) {
auto pr = parametric_instances_.equal_range(hash_value); auto pr = parametric_instances_.equal_range(hash_value);
...@@ -62,6 +63,7 @@ struct ParametricStorageManager { ...@@ -62,6 +63,7 @@ struct ParametricStorageManager {
// In order to prevent hash conflicts, the unordered_multimap data structure // In order to prevent hash conflicts, the unordered_multimap data structure
// is used for storage. // is used for storage.
std::unordered_multimap<size_t, StorageBase *> parametric_instances_; std::unordered_multimap<size_t, StorageBase *> parametric_instances_;
std::function<void(StorageBase *)> destroy_;
}; };
StorageManager::StorageManager() {} StorageManager::StorageManager() {}
...@@ -95,12 +97,13 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( ...@@ -95,12 +97,13 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
return parameterless_instance; return parameterless_instance;
} }
void StorageManager::RegisterParametricStorageImpl(TypeId type_id) { void StorageManager::RegisterParametricStorageImpl(
TypeId type_id, std::function<void(StorageBase *)> destroy) {
std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_); std::lock_guard<ir::SpinLock> guard(parametric_instance_lock_);
VLOG(4) << "Register a parametric storage of: [TypeId_hash=" VLOG(4) << "Register a parametric storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "]."; << std::hash<ir::TypeId>()(type_id) << "].";
parametric_instance_.emplace(type_id, parametric_instance_.emplace(
std::make_unique<ParametricStorageManager>()); type_id, std::make_unique<ParametricStorageManager>(destroy));
} }
void StorageManager::RegisterParameterlessStorageImpl( void StorageManager::RegisterParameterlessStorageImpl(
......
...@@ -100,7 +100,9 @@ class StorageManager { ...@@ -100,7 +100,9 @@ class StorageManager {
/// ///
template <typename Storage> template <typename Storage>
void RegisterParametricStorage(TypeId type_id) { void RegisterParametricStorage(TypeId type_id) {
return RegisterParametricStorageImpl(type_id); return RegisterParametricStorageImpl(type_id, [](StorageBase *storage) {
delete static_cast<Storage *>(storage);
});
} }
/// ///
...@@ -129,7 +131,8 @@ class StorageManager { ...@@ -129,7 +131,8 @@ class StorageManager {
StorageBase *GetParameterlessStorageImpl(TypeId type_id); StorageBase *GetParameterlessStorageImpl(TypeId type_id);
void RegisterParametricStorageImpl(TypeId type_id); void RegisterParametricStorageImpl(
TypeId type_id, std::function<void(StorageBase *)> destroy);
void RegisterParameterlessStorageImpl( void RegisterParameterlessStorageImpl(
TypeId type_id, std::function<StorageBase *()> constructor); TypeId type_id, std::function<StorageBase *()> constructor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册