diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc index dc97e8c0233a60cfe789e33e63782d94ced907e9..d53431d260eaffd07ea8141b40a58b5df000ac63 100644 --- a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc @@ -73,6 +73,46 @@ PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) { } // namespace patterns +UnsqueezeEltwiseFusePass::UnsqueezeEltwiseFusePass() { + AddOpCompat(OpCompat("unsqueeze2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("AxesTensor") + .IsOptional() + .IsTensor() + .End() + .AddInput("AxesTensorList") + .IsOptional() + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axes") + .IsType>() + .End(); + + AddOpCompat(OpCompat("elementwise_mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + // The attribute value is - 1 before fusion and 0 after fusion + .AddAttr("axis") + .IsIntIn({-1, 0}) + .End(); +} + void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); @@ -100,7 +140,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { LOG(WARNING) << "The subgraph is empty."; return; } - + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } VLOG(4) << "handle UnsqueezeEltwise fuse"; GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern); @@ -123,6 +166,10 @@ void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { IR_NODE_LINK_TO(eltwise_op, eltwise_out); GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out}); found_subgraph_count++; + if (!IsCompat(*eltwise_op->Op())) { + LOG(WARNING) << "unsqueeze2_eltwise_fuse_pass op compat failed."; + return; + } } }; diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h index 3be29f0e0288855e3f7e940c527f80b66edccca9..0410e5b3f330cdf4f20df6b9b17e661e1a699b6c 100644 --- a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h @@ -34,6 +34,7 @@ class Graph; // it maybe change in runtime. class UnsqueezeEltwiseFusePass : public FusePassBase { public: + UnsqueezeEltwiseFusePass(); virtual ~UnsqueezeEltwiseFusePass() {} protected: