diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index a8595d55b31b05fb6f9a2f9ff5ff7a8787678100..f2e0e9613fc44ccc0c08f7e927196cfa1ee2e5c9 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -863,8 +863,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, auto* mul0_op_desc = mul0->Op(); - // all mul op has same input. - if (multihead_op_desc.HasAttr("Input_scale")) { + // all mul op has same input. Set int8 attr: Input_scale + if (mul0_op_desc->HasAttr("Input_scale")) { multihead_op_desc.SetAttr("Input_scale", mul0_op_desc->GetAttr("Input_scale")); }