diff --git a/paddle/ir/core/enforce.h b/paddle/ir/core/enforce.h index e87ac0c41a07ee19558401858f5269587fe3c039..10735297f305d54d6ceb4daa7222ff92ac10f9fd 100644 --- a/paddle/ir/core/enforce.h +++ b/paddle/ir/core/enforce.h @@ -25,8 +25,10 @@ // there is no equivalent intrinsics in msvc. #define UNLIKELY(condition) (condition) #endif - -inline bool is_error(bool stat) { return !stat; } +template +inline bool is_error(const T& stat) { + return !stat; +} namespace ir { class IrNotMetException : public std::exception { @@ -55,7 +57,7 @@ class IrNotMetException : public std::exception { #define IR_ENFORCE(COND, ...) \ do { \ - auto __cond__ = (COND); \ + bool __cond__(COND); \ if (UNLIKELY(is_error(__cond__))) { \ try { \ throw ir::IrNotMetException( \ diff --git a/paddle/ir/core/storage_manager.cc b/paddle/ir/core/storage_manager.cc index 8bc74b23c1d10a4dfdf9d9d7f0847237fb36ead0..8385c3d1fe45987679bfd622c9970fa6e5796bc4 100644 --- a/paddle/ir/core/storage_manager.cc +++ b/paddle/ir/core/storage_manager.cc @@ -25,11 +25,12 @@ namespace ir { struct ParametricStorageManager { using StorageBase = StorageManager::StorageBase; - ParametricStorageManager() {} + explicit ParametricStorageManager(std::function destroy) + : destroy_(destroy) {} ~ParametricStorageManager() { for (const auto &instance : parametric_instances_) { - delete instance.second; + destroy_(instance.second); } parametric_instances_.clear(); } @@ -37,7 +38,7 @@ struct ParametricStorageManager { // Get the storage of parametric type, if not in the cache, create and // insert the cache. StorageBase *GetOrCreate(std::size_t hash_value, - std::function equal_func, + std::function equal_func, std::function constructor) { if (parametric_instances_.count(hash_value) != 0) { auto pr = parametric_instances_.equal_range(hash_value); @@ -62,6 +63,7 @@ struct ParametricStorageManager { // In order to prevent hash conflicts, the unordered_multimap data structure // is used for storage. std::unordered_multimap parametric_instances_; + std::function destroy_; }; StorageManager::StorageManager() {} @@ -95,12 +97,13 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( return parameterless_instance; } -void StorageManager::RegisterParametricStorageImpl(TypeId type_id) { +void StorageManager::RegisterParametricStorageImpl( + TypeId type_id, std::function destroy) { std::lock_guard guard(parametric_instance_lock_); VLOG(4) << "Register a parametric storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; - parametric_instance_.emplace(type_id, - std::make_unique()); + parametric_instance_.emplace( + type_id, std::make_unique(destroy)); } void StorageManager::RegisterParameterlessStorageImpl( diff --git a/paddle/ir/core/storage_manager.h b/paddle/ir/core/storage_manager.h index 6b20afb8a8067cca8b72def14f8cb9058eccfce6..9ff7f76e98def5e7870fbb324173cd83dd36bfa3 100644 --- a/paddle/ir/core/storage_manager.h +++ b/paddle/ir/core/storage_manager.h @@ -100,7 +100,9 @@ class StorageManager { /// template void RegisterParametricStorage(TypeId type_id) { - return RegisterParametricStorageImpl(type_id); + return RegisterParametricStorageImpl(type_id, [](StorageBase *storage) { + delete static_cast(storage); + }); } /// @@ -129,7 +131,8 @@ class StorageManager { StorageBase *GetParameterlessStorageImpl(TypeId type_id); - void RegisterParametricStorageImpl(TypeId type_id); + void RegisterParametricStorageImpl( + TypeId type_id, std::function destroy); void RegisterParameterlessStorageImpl( TypeId type_id, std::function constructor);