diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc index c72ee0b6b2f994eb4ce82ffc5f56e3628305b0d4..0ded4ee1a5de882c6cf3fd04c37030b563104fcd 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc @@ -15,14 +15,12 @@ #include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h" #include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" -// #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" namespace paddle { namespace dialect { APIBuilder::APIBuilder() : builder_(nullptr) { ctx_ = ir::IrContext::Instance(); - // ctx_->GetOrRegisterDialect(); } void APIBuilder::SetProgram(ir::Program* program) { diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 1f42052ca7aaa6236e53459b2766983f5e07b6da..6c6957c3e00e08f08872bb28232c17861a887d01 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -410,6 +410,10 @@ void BindUtils(pybind11::module *m) { []() { APIBuilder::Instance().ResetInsertionPointToStart(); }); m->def("reset_insertion_point_to_end", []() { APIBuilder::Instance().ResetInsertionPointToEnd(); }); + m->def("register_paddle_dialect", []() { + ir::IrContext::Instance() + ->GetOrRegisterDialect(); + }); m->def( "translate_to_new_ir", [](const ::paddle::framework::ProgramDesc &legacy_program) { diff --git a/python/paddle/ir/__init__.py b/python/paddle/ir/__init__.py index be8ddeba2298f97a48c12ae9fed4dc9c00fd5788..4fee1c1a064c5d3ae6cee6cb886ef2b44c8bb5b5 100644 --- a/python/paddle/ir/__init__.py +++ b/python/paddle/ir/__init__.py @@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import ( reset_insertion_point_to_start, reset_insertion_point_to_end, check_unregistered_ops, + register_paddle_dialect, PassManager, ) # noqa: F401 diff --git a/python/paddle/new_ir_utils.py b/python/paddle/new_ir_utils.py index bb016158d16160e9eaaecb4e8f2b050afd45724f..443ac48ae829c43f454699a716a89adccce03975 100644 --- a/python/paddle/new_ir_utils.py +++ b/python/paddle/new_ir_utils.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. - import paddle def _switch_to_new_ir(): if paddle.ir.core._use_new_ir_api(): paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": True}) + paddle.ir.register_paddle_dialect() paddle.static.Program = paddle.ir.Program paddle.fluid.Program = paddle.ir.Program paddle.fluid.program_guard = paddle.ir.core.program_guard