未验证 提交 ca5585e9 编写于 作者: Y YuanRisheng 提交者: GitHub

[BugFix]Fix test_build_model error (#56633)

* fix test bugs

* delete code
上级 ecff21e7
...@@ -15,14 +15,12 @@ ...@@ -15,14 +15,12 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h" #include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h"
#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
// #include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
APIBuilder::APIBuilder() : builder_(nullptr) { APIBuilder::APIBuilder() : builder_(nullptr) {
ctx_ = ir::IrContext::Instance(); ctx_ = ir::IrContext::Instance();
// ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
} }
void APIBuilder::SetProgram(ir::Program* program) { void APIBuilder::SetProgram(ir::Program* program) {
......
...@@ -410,6 +410,10 @@ void BindUtils(pybind11::module *m) { ...@@ -410,6 +410,10 @@ void BindUtils(pybind11::module *m) {
[]() { APIBuilder::Instance().ResetInsertionPointToStart(); }); []() { APIBuilder::Instance().ResetInsertionPointToStart(); });
m->def("reset_insertion_point_to_end", m->def("reset_insertion_point_to_end",
[]() { APIBuilder::Instance().ResetInsertionPointToEnd(); }); []() { APIBuilder::Instance().ResetInsertionPointToEnd(); });
m->def("register_paddle_dialect", []() {
ir::IrContext::Instance()
->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
});
m->def( m->def(
"translate_to_new_ir", "translate_to_new_ir",
[](const ::paddle::framework::ProgramDesc &legacy_program) { [](const ::paddle::framework::ProgramDesc &legacy_program) {
......
...@@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import ( ...@@ -28,6 +28,7 @@ from paddle.fluid.libpaddle.ir import (
reset_insertion_point_to_start, reset_insertion_point_to_start,
reset_insertion_point_to_end, reset_insertion_point_to_end,
check_unregistered_ops, check_unregistered_ops,
register_paddle_dialect,
PassManager, PassManager,
) # noqa: F401 ) # noqa: F401
......
...@@ -12,13 +12,13 @@ ...@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle import paddle
def _switch_to_new_ir(): def _switch_to_new_ir():
if paddle.ir.core._use_new_ir_api(): if paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": True}) paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": True})
paddle.ir.register_paddle_dialect()
paddle.static.Program = paddle.ir.Program paddle.static.Program = paddle.ir.Program
paddle.fluid.Program = paddle.ir.Program paddle.fluid.Program = paddle.ir.Program
paddle.fluid.program_guard = paddle.ir.core.program_guard paddle.fluid.program_guard = paddle.ir.core.program_guard
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册