From 3020ad03d8a25d8cbe6f689479769504e0265f12 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Sat, 10 Jun 2023 09:45:11 +0800 Subject: [PATCH] [IR] Fix ir StorageManager bug (#54503) * fix ir storage bug * refine code * refine code * fix bug * refine code * refine code --- paddle/ir/core/enforce.h | 8 +++++--- paddle/ir/core/storage_manager.cc | 15 +++++++++------ paddle/ir/core/storage_manager.h | 7 +++++-- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/paddle/ir/core/enforce.h b/paddle/ir/core/enforce.h index e87ac0c41a0..10735297f30 100644 --- a/paddle/ir/core/enforce.h +++ b/paddle/ir/core/enforce.h @@ -25,8 +25,10 @@ // there is no equivalent intrinsics in msvc. #define UNLIKELY(condition) (condition) #endif - -inline bool is_error(bool stat) { return !stat; } +template +inline bool is_error(const T& stat) { + return !stat; +} namespace ir { class IrNotMetException : public std::exception { @@ -55,7 +57,7 @@ class IrNotMetException : public std::exception { #define IR_ENFORCE(COND, ...) \ do { \ - auto __cond__ = (COND); \ + bool __cond__(COND); \ if (UNLIKELY(is_error(__cond__))) { \ try { \ throw ir::IrNotMetException( \ diff --git a/paddle/ir/core/storage_manager.cc b/paddle/ir/core/storage_manager.cc index 8bc74b23c1d..8385c3d1fe4 100644 --- a/paddle/ir/core/storage_manager.cc +++ b/paddle/ir/core/storage_manager.cc @@ -25,11 +25,12 @@ namespace ir { struct ParametricStorageManager { using StorageBase = StorageManager::StorageBase; - ParametricStorageManager() {} + explicit ParametricStorageManager(std::function destroy) + : destroy_(destroy) {} ~ParametricStorageManager() { for (const auto &instance : parametric_instances_) { - delete instance.second; + destroy_(instance.second); } parametric_instances_.clear(); } @@ -37,7 +38,7 @@ struct ParametricStorageManager { // Get the storage of parametric type, if not in the cache, create and // insert the cache. StorageBase *GetOrCreate(std::size_t hash_value, - std::function equal_func, + std::function equal_func, std::function constructor) { if (parametric_instances_.count(hash_value) != 0) { auto pr = parametric_instances_.equal_range(hash_value); @@ -62,6 +63,7 @@ struct ParametricStorageManager { // In order to prevent hash conflicts, the unordered_multimap data structure // is used for storage. std::unordered_multimap parametric_instances_; + std::function destroy_; }; StorageManager::StorageManager() {} @@ -95,12 +97,13 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( return parameterless_instance; } -void StorageManager::RegisterParametricStorageImpl(TypeId type_id) { +void StorageManager::RegisterParametricStorageImpl( + TypeId type_id, std::function destroy) { std::lock_guard guard(parametric_instance_lock_); VLOG(4) << "Register a parametric storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; - parametric_instance_.emplace(type_id, - std::make_unique()); + parametric_instance_.emplace( + type_id, std::make_unique(destroy)); } void StorageManager::RegisterParameterlessStorageImpl( diff --git a/paddle/ir/core/storage_manager.h b/paddle/ir/core/storage_manager.h index 6b20afb8a80..9ff7f76e98d 100644 --- a/paddle/ir/core/storage_manager.h +++ b/paddle/ir/core/storage_manager.h @@ -100,7 +100,9 @@ class StorageManager { /// template void RegisterParametricStorage(TypeId type_id) { - return RegisterParametricStorageImpl(type_id); + return RegisterParametricStorageImpl(type_id, [](StorageBase *storage) { + delete static_cast(storage); + }); } /// @@ -129,7 +131,8 @@ class StorageManager { StorageBase *GetParameterlessStorageImpl(TypeId type_id); - void RegisterParametricStorageImpl(TypeId type_id); + void RegisterParametricStorageImpl( + TypeId type_id, std::function destroy); void RegisterParameterlessStorageImpl( TypeId type_id, std::function constructor); -- GitLab