diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc index be87ea6bc7f4c913bb575aae5863a20e74ca1dd0..dd922a7359f428897de1b718243ebe89c5dc8296 100644 --- a/paddle/ir/ir_context.cc +++ b/paddle/ir/ir_context.cc @@ -99,6 +99,10 @@ class IrContextImpl { return nullptr; } + bool IsOpInfoRegistered(const std::string &name) { + return registed_op_infos_.find(name) != registed_op_infos_.end(); + } + void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { std::lock_guard guard(registed_op_infos_lock_); VLOG(4) << "Register an operation of: [Name=" << name @@ -125,7 +129,11 @@ class IrContextImpl { registed_dialect_.emplace(name, dialect); } - Dialect *GetDialect(std::string name) { + bool IsDialectRegistered(const std::string &name) { + return registed_dialect_.find(name) != registed_dialect_.end(); + } + + Dialect *GetDialect(const std::string &name) { std::lock_guard guard(registed_dialect_lock_); auto iter = registed_dialect_.find(name); if (iter != registed_dialect_.end()) { @@ -221,17 +229,15 @@ AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) { } Dialect *IrContext::GetOrRegisterDialect( - std::string dialect_name, std::function constructor) { + const std::string &dialect_name, std::function constructor) { VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name << "]."; - Dialect *dialect = impl().GetDialect(dialect_name); - if (dialect == nullptr) { + if (!impl().IsDialectRegistered(dialect_name)) { VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name << "]."; - dialect = constructor(); - impl().RegisterDialect(dialect_name, dialect); + impl().RegisterDialect(dialect_name, constructor()); } - return dialect; + return impl().GetDialect(dialect_name); } std::vector IrContext::GetRegisteredDialects() { @@ -271,7 +277,9 @@ void IrContext::RegisterOpInfo(Dialect *dialect, size_t attributes_num, const char **attributes_name, VerifyPtr verify) { - if (GetRegisteredOpInfo(name) == nullptr) { + if (impl().IsOpInfoRegistered(name)) { + LOG(WARNING) << name << " op already registered."; + } else { OpInfoImpl *opinfo = OpInfoImpl::create(dialect, op_id, name, @@ -281,9 +289,7 @@ void IrContext::RegisterOpInfo(Dialect *dialect, attributes_name, verify); impl().RegisterOpInfo(name, opinfo); - VLOG(4) << "Op " << name << " registered into IrContext. --->"; - } else { - LOG(WARNING) << name << " op already registered."; + VLOG(4) << name << " op registered into IrContext. --->"; } } diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h index 08c7997d3b1fc85366a36553cd1f364394c44e35..c5fb7fa5550b6f89b12ce9658c1b959510f9f21b 100644 --- a/paddle/ir/ir_context.h +++ b/paddle/ir/ir_context.h @@ -143,7 +143,7 @@ class IrContext { /// /// \return The dialect named "dialect_name" in the context. /// - Dialect *GetOrRegisterDialect(std::string dialect_name, + Dialect *GetOrRegisterDialect(const std::string &dialect_name, std::function constructor); ///