diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index f11f52a0b1cdaaea3673b25828f8e4e7d2f3cf18..f1cd1face3444e4bf88239640e09c1df5df44576 100755 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -74,6 +74,16 @@ static const std::vector xpu_support_subgraph_passes = { "xpu_delete_cast_op_pass", }; +static std::vector support_subgraph_generate_passes; + +void Pass::AddSupportSubgraphPass(const std::string &pass_type) { + if (std::find(support_subgraph_generate_passes.begin(), + support_subgraph_generate_passes.end(), + pass_type) == support_subgraph_generate_passes.end()) { + support_subgraph_generate_passes.push_back(pass_type); + } +} + Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; CheckPrevPass(); @@ -117,7 +127,10 @@ Graph *Pass::Apply(Graph *graph) const { subgraph_passes = support_subgraph_passes; } if (graph->IsMainGraph() && - std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) { + (std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) || + std::count(support_subgraph_generate_passes.begin(), + support_subgraph_generate_passes.end(), + Type()))) { for (size_t i = 1; i < graph->SubGraphsSize(); i++) { auto *sub_graph = graph->GetSubGraph(i); if (!sub_graph->Has(framework::ir::kParamScopeAttr)) { diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 1f59466e1cd802e18135d9855ced50d21433e907..473890a4b786ba4ba1ef59894c521fd16d100929 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -168,6 +168,8 @@ class Pass { virtual bool SupportApplyProgramViaGraph() const { return true; } + static void AddSupportSubgraphPass(const std::string &pass_type); + protected: virtual void ApplyImpl(Graph *graph UNUSED) const { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c4036944bc18ae953a5c922fb6fdb19f9c2dda79..d55cab98b1eba60b1fe8f3740332d6e5fe36ec95 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2319,6 +2319,9 @@ All parameter, weight, gradient are variables in Paddle. auto pass = framework::ir::PassRegistry::Instance().Get(pass_type); return std::shared_ptr(std::move(pass)); }); + m.def("register_subgraph_pass", [](const std::string &pass_type) { + framework::ir::Pass::AddSupportSubgraphPass(pass_type); + }); m.def("size_of_dtype", framework::SizeOfType); py::class_(m, "_ProfilerResult")