From d73db135ed9c867d508437166188e9137cf2523f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Wed, 24 May 2023 15:59:53 +0800 Subject: [PATCH] [IR] fine-tune the interface of ir-context class. (#54031) --- paddle/ir/ir_context.cc | 28 +++++++++++++++++----------- paddle/ir/ir_context.h | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/paddle/ir/ir_context.cc b/paddle/ir/ir_context.cc index be87ea6bc7f..dd922a7359f 100644 --- a/paddle/ir/ir_context.cc +++ b/paddle/ir/ir_context.cc @@ -99,6 +99,10 @@ class IrContextImpl { return nullptr; } + bool IsOpInfoRegistered(const std::string &name) { + return registed_op_infos_.find(name) != registed_op_infos_.end(); + } + void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) { std::lock_guard guard(registed_op_infos_lock_); VLOG(4) << "Register an operation of: [Name=" << name @@ -125,7 +129,11 @@ class IrContextImpl { registed_dialect_.emplace(name, dialect); } - Dialect *GetDialect(std::string name) { + bool IsDialectRegistered(const std::string &name) { + return registed_dialect_.find(name) != registed_dialect_.end(); + } + + Dialect *GetDialect(const std::string &name) { std::lock_guard guard(registed_dialect_lock_); auto iter = registed_dialect_.find(name); if (iter != registed_dialect_.end()) { @@ -221,17 +229,15 @@ AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) { } Dialect *IrContext::GetOrRegisterDialect( - std::string dialect_name, std::function constructor) { + const std::string &dialect_name, std::function constructor) { VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name << "]."; - Dialect *dialect = impl().GetDialect(dialect_name); - if (dialect == nullptr) { + if (!impl().IsDialectRegistered(dialect_name)) { VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name << "]."; - dialect = constructor(); - impl().RegisterDialect(dialect_name, dialect); + impl().RegisterDialect(dialect_name, constructor()); } - return dialect; + return impl().GetDialect(dialect_name); } std::vector IrContext::GetRegisteredDialects() { @@ -271,7 +277,9 @@ void IrContext::RegisterOpInfo(Dialect *dialect, size_t attributes_num, const char **attributes_name, VerifyPtr verify) { - if (GetRegisteredOpInfo(name) == nullptr) { + if (impl().IsOpInfoRegistered(name)) { + LOG(WARNING) << name << " op already registered."; + } else { OpInfoImpl *opinfo = OpInfoImpl::create(dialect, op_id, name, @@ -281,9 +289,7 @@ void IrContext::RegisterOpInfo(Dialect *dialect, attributes_name, verify); impl().RegisterOpInfo(name, opinfo); - VLOG(4) << "Op " << name << " registered into IrContext. --->"; - } else { - LOG(WARNING) << name << " op already registered."; + VLOG(4) << name << " op registered into IrContext. --->"; } } diff --git a/paddle/ir/ir_context.h b/paddle/ir/ir_context.h index 08c7997d3b1..c5fb7fa5550 100644 --- a/paddle/ir/ir_context.h +++ b/paddle/ir/ir_context.h @@ -143,7 +143,7 @@ class IrContext { /// /// \return The dialect named "dialect_name" in the context. /// - Dialect *GetOrRegisterDialect(std::string dialect_name, + Dialect *GetOrRegisterDialect(const std::string &dialect_name, std::function constructor); /// -- GitLab