diff --git a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc
index 6591e1bbf12d06c360f76057420de2c08d83e03c..829c904dcfa1956c2c874cb0ac458e1f17b25674 100644
--- a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc
@@ -79,6 +79,10 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
     OpDesc desc(matmul0_op->Op()->Block());
     desc.SetType("multihead_matmul");
     desc.SetInput("Input", {subgraph.at(x)->Name()});
+    if (matmul0_out->Var()->GetShape().size() != 3) {
+      VLOG(3) << "vit_attention_fuse_pass only support input.dim == 3";
+      return;
+    }
     // refactor W and Bias
     auto* w_tensor =
         scope->FindVar(matmul0_in_y->Name())->GetMutable<phi::DenseTensor>();