Unverified commit b9d66e6b authored by feng_shuai, committed by GitHub

fix bugs for vit attention pass (#45721)

* fix: vit attention pass

* refresh CI
Parent 4a9895b1
@@ -79,6 +79,10 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
     OpDesc desc(matmul0_op->Op()->Block());
     desc.SetType("multihead_matmul");
     desc.SetInput("Input", {subgraph.at(x)->Name()});
+    if (matmul0_out->Var()->GetShape().size() != 3) {
+      VLOG(3) << "vit_attention_fuse_pass only support input.dim == 3";
+      return;
+    }
     // refactor W and Bias
     auto* w_tensor =
         scope->FindVar(matmul0_in_y->Name())->GetMutable<LoDTensor>();
...
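Below is a minimal, self-contained sketch of the guard this commit adds: the fuse pass now bails out early unless the attention input is a rank-3 tensor ([batch, tokens, hidden]), since that is the only layout the fused multihead_matmul op consumes. The `VarDesc` struct and `CanFuseVitAttention` helper are hypothetical stand-ins for illustration, not the real PaddlePaddle pass API.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the pass's view of a variable's shape.
struct VarDesc {
  std::vector<int64_t> shape;
};

// Mirrors the early-return check added in VitAttentionFusePass::ApplyImpl:
// returns true only when the fusion may be applied.
bool CanFuseVitAttention(const VarDesc& matmul0_out) {
  if (matmul0_out.shape.size() != 3) {
    // The real pass logs via VLOG(3) and leaves the subgraph untouched.
    std::cerr << "vit_attention_fuse_pass only supports input.dim == 3\n";
    return false;
  }
  return true;
}

int main() {
  VarDesc rank3{{1, 197, 768}};     // typical ViT input: [batch, tokens, hidden]
  VarDesc rank4{{1, 12, 197, 64}};  // head-split layout the fused op cannot take
  std::cout << CanFuseVitAttention(rank3) << "\n";  // prints 1: fusion proceeds
  std::cout << CanFuseVitAttention(rank4) << "\n";  // prints 0: pass returns early
  return 0;
}
```

Without this check, the pass could fuse a subgraph whose input rank the multihead_matmul kernel cannot handle, which appears to be the bug the commit title refers to.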