From b9d66e6ba4112c3a747aace936e4b91761c8e237 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Mon, 5 Sep 2022 10:52:21 +0800 Subject: [PATCH] fix bugs for vit attention pass (#45721) * fix: vit attention pass * refresh CI --- paddle/fluid/framework/ir/vit_attention_fuse_pass.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc index 6591e1bbf12..829c904dcfa 100644 --- a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc +++ b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc @@ -79,6 +79,10 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { OpDesc desc(matmul0_op->Op()->Block()); desc.SetType("multihead_matmul"); desc.SetInput("Input", {subgraph.at(x)->Name()}); + if (matmul0_out->Var()->GetShape().size() != 3) { + VLOG(3) << "vit_attention_fuse_pass only support input.dim == 3"; + return; + } // refactor W and Bias auto* w_tensor = scope->FindVar(matmul0_in_y->Name())->GetMutable(); -- GitLab