From a863cf73165e6aa9351ce7766462e1f71c11559e Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 29 Jun 2021 18:51:19 +0800 Subject: [PATCH] unsqueeze2_eltwise_fuse_pass_init (#33808) --- .../ir/unsqueeze2_eltwise_fuse_pass.cc | 49 ++++++++++++++++++- .../ir/unsqueeze2_eltwise_fuse_pass.h | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc index dc97e8c023..d53431d260 100644 --- a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc @@ -73,6 +73,46 @@ PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) { } // namespace patterns +UnsqueezeEltwiseFusePass::UnsqueezeEltwiseFusePass() { + AddOpCompat(OpCompat("unsqueeze2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("AxesTensor") + .IsOptional() + .IsTensor() + .End() + .AddInput("AxesTensorList") + .IsOptional() + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axes") + .IsType>() + .End(); + + AddOpCompat(OpCompat("elementwise_mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + // The attribute value is - 1 before fusion and 0 after fusion + .AddAttr("axis") + .IsIntIn({-1, 0}) + .End(); +} + void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); @@ -100,7 +140,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "The subgraph is empty."; return; } - + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle UnsqueezeEltwise fuse"; GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern); @@ -123,6 +166,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { IR_NODE_LINK_TO(eltwise_op, eltwise_out); GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out}); found_subgraph_count++; + if (!IsCompat(*eltwise_op->Op())) { + LOG(WARNING) << "unsqueeze2_eltwise_fuse_pass op compat failed."; + return; + } } }; diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h index 3be29f0e02..0410e5b3f3 100644 --- a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h @@ -34,6 +34,7 @@ class Graph; // it maybe change in runtime. class UnsqueezeEltwiseFusePass : public FusePassBase { public: + UnsqueezeEltwiseFusePass(); virtual ~UnsqueezeEltwiseFusePass() {} protected: -- GitLab