diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 00efffff1440c8151a0ec65e0b7e68b3c42173bd..d97e12fbe450688e3bd94791685bd41954a4cd3d 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -91,16 +91,14 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { return *global; } -static void RegisterDialects() { - static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - return true; - }(); - (void)init_once; +static void RegisterDialects(mlir::DialectRegistry& registry) { + // clang-format off + registry.insert(); + // clang-format on } Status MlirFunctionOptimizationPass::Run( @@ -126,9 +124,8 @@ Status MlirFunctionOptimizationPass::Run( << " passes)"; GraphDebugInfo debug_info; - RegisterDialects(); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); + RegisterDialects(context.getDialectRegistry()); GraphImportConfig import_config; import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; @@ -207,9 +204,8 @@ Status MlirV1CompatGraphOptimizationPass::Run( << " passes)"; GraphDebugInfo debug_info; - RegisterDialects(); mlir::MLIRContext context; - context.loadAllGloballyRegisteredDialects(); + RegisterDialects(context.getDialectRegistry()); GraphImportConfig import_config; import_config.upgrade_legacy = true; // Restrict functionalization to TPU nodes to avoid problems in v1 session