diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index 5c23e826a2decffdaf9c138cd9bd50e098aa9286..a8147fd466b5216b197f9f275ea3e7abf6ff99f9 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -903,8 +903,6 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
             float, softmax_qk_op_desc->GetAttr("out_threshold"));
         multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
       }
-    } else {
-      multihead_op_desc.SetAttr("qkv2context_plugin_int8", false);
     }
 
     auto* multihead = graph->CreateOpNode(&multihead_op_desc);
diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
index 2a9b015ce982cb0f6982058a4e9280f951f45d62..a073acc96c0d4f27e32ccf61dfe0b1414973e7cc 100644
--- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -42,8 +42,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
     float* weight_data = nullptr;
 
     bool enable_int8 = op_desc.HasAttr("enable_int8");
-    bool qkv2context_plugin_int8 =
-        BOOST_GET_CONST(bool, op_desc.GetAttr("qkv2context_plugin_int8"));
+    bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8");
     float in_scale = 0.;
 
     if (enable_int8) {