未验证 提交 6216beb3 编写于 作者: R ronnywang 提交者: GitHub

[CustomPass] add register_pass api (#55511)

上级 3f17596a
......@@ -74,6 +74,16 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"xpu_delete_cast_op_pass",
};
static std::vector<std::string> support_subgraph_generate_passes;
void Pass::AddSupportSubgraphPass(const std::string &pass_type) {
if (std::find(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
pass_type) == support_subgraph_generate_passes.end()) {
support_subgraph_generate_passes.push_back(pass_type);
}
}
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass();
......@@ -117,7 +127,10 @@ Graph *Pass::Apply(Graph *graph) const {
subgraph_passes = support_subgraph_passes;
}
if (graph->IsMainGraph() &&
std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
(std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) ||
std::count(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
Type()))) {
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
......
......@@ -168,6 +168,8 @@ class Pass {
virtual bool SupportApplyProgramViaGraph() const { return true; }
static void AddSupportSubgraphPass(const std::string &pass_type);
protected:
virtual void ApplyImpl(Graph *graph UNUSED) const {
PADDLE_THROW(platform::errors::Unimplemented(
......
......@@ -2319,6 +2319,9 @@ All parameter, weight, gradient are variables in Paddle.
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
m.def("register_subgraph_pass", [](const std::string &pass_type) {
framework::ir::Pass::AddSupportSubgraphPass(pass_type);
});
m.def("size_of_dtype", framework::SizeOfType);
py::class_<paddle::platform::ProfilerResult>(m, "_ProfilerResult")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册