未验证 提交 d73db135 编写于 作者: 王明冬 提交者: GitHub

[IR] fine-tune the interface of ir-context class. (#54031)

上级 e862753c
......@@ -99,6 +99,10 @@ class IrContextImpl {
return nullptr;
}
bool IsOpInfoRegistered(const std::string &name) {
return registed_op_infos_.find(name) != registed_op_infos_.end();
}
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
VLOG(4) << "Register an operation of: [Name=" << name
......@@ -125,7 +129,11 @@ class IrContextImpl {
registed_dialect_.emplace(name, dialect);
}
Dialect *GetDialect(std::string name) {
bool IsDialectRegistered(const std::string &name) {
return registed_dialect_.find(name) != registed_dialect_.end();
}
Dialect *GetDialect(const std::string &name) {
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
auto iter = registed_dialect_.find(name);
if (iter != registed_dialect_.end()) {
......@@ -221,17 +229,15 @@ AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
}
Dialect *IrContext::GetOrRegisterDialect(
std::string dialect_name, std::function<Dialect *()> constructor) {
const std::string &dialect_name, std::function<Dialect *()> constructor) {
VLOG(4) << "Try to get or register a Dialect of: [name=" << dialect_name
<< "].";
Dialect *dialect = impl().GetDialect(dialect_name);
if (dialect == nullptr) {
if (!impl().IsDialectRegistered(dialect_name)) {
VLOG(4) << "Create and register a new Dialect of: [name=" << dialect_name
<< "].";
dialect = constructor();
impl().RegisterDialect(dialect_name, dialect);
impl().RegisterDialect(dialect_name, constructor());
}
return dialect;
return impl().GetDialect(dialect_name);
}
std::vector<Dialect *> IrContext::GetRegisteredDialects() {
......@@ -271,7 +277,9 @@ void IrContext::RegisterOpInfo(Dialect *dialect,
size_t attributes_num,
const char **attributes_name,
VerifyPtr verify) {
if (GetRegisteredOpInfo(name) == nullptr) {
if (impl().IsOpInfoRegistered(name)) {
LOG(WARNING) << name << " op already registered.";
} else {
OpInfoImpl *opinfo = OpInfoImpl::create(dialect,
op_id,
name,
......@@ -281,9 +289,7 @@ void IrContext::RegisterOpInfo(Dialect *dialect,
attributes_name,
verify);
impl().RegisterOpInfo(name, opinfo);
VLOG(4) << "Op " << name << " registered into IrContext. --->";
} else {
LOG(WARNING) << name << " op already registered.";
VLOG(4) << name << " op registered into IrContext. --->";
}
}
......
......@@ -143,7 +143,7 @@ class IrContext {
///
/// \return The dialect named "dialect_name" in the context.
///
Dialect *GetOrRegisterDialect(std::string dialect_name,
Dialect *GetOrRegisterDialect(const std::string &dialect_name,
std::function<Dialect *()> constructor);
///
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册